single_dataset.py 1.5 KB
Newer Older
L
LielinJiang 已提交
1 2 3 4 5 6
import cv2
import paddle
from .base_dataset import BaseDataset, get_transform
from .image_folder import make_dataset

from .builder import DATASETS
L
LielinJiang 已提交
7
from .transforms.builder import build_transforms
L
LielinJiang 已提交
8 9 10 11 12 13 14 15 16 17 18 19 20 21 22


@DATASETS.register()
class SingleDataset(BaseDataset):
    """Dataset over the image paths returned by ``make_dataset`` for one root.

    Each item is a dict with the transformed image under ``'A'`` and the
    path it was read from under ``'A_paths'``.
    """
    def __init__(self, cfg):
        """Initialize this dataset class.

        Args:
            cfg (dict) -- stores all the experiment flags; this class reads
                ``dataroot``, ``max_dataset_size`` and ``transforms``.
        """
        BaseDataset.__init__(self, cfg)
        # Sorted so that a given index always refers to the same file.
        self.A_paths = sorted(make_dataset(cfg.dataroot, cfg.max_dataset_size))
        self.transform = build_transforms(self.cfg.transforms)

    def __getitem__(self, index):
        """Return a data point and its metadata information.

        Parameters:
            index - - position of the sample in the sorted path list

        Returns a dictionary that contains A and A_paths
            A - - the image after ``self.transform`` has been applied
            A_paths(str) - - the path of the image

        Raises:
            IOError - - if the image file cannot be read
        """
        A_path = self.A_paths[index]
        A_img = cv2.imread(A_path)
        if A_img is None:
            # cv2.imread reports a missing/unreadable file by returning None
            # instead of raising; fail here so the error names the file.
            raise IOError('Failed to read image: {}'.format(A_path))
        A = self.transform(A_img)

        return {'A': A, 'A_paths': A_path}

    def __len__(self):
        """Return the total number of images in the dataset."""
        return len(self.A_paths)

    def get_path_by_indexs(self, indexs):
        """Return the image paths stored at the given positions.

        Parameters:
            indexs - - an iterable of integer indices, or a tensor-like
                object exposing ``numpy()`` (it is converted first)

        Returns a list of paths, in the same order as ``indexs``.
        """
        # Duck-type on ``numpy()`` rather than checking for one specific
        # paddle class, so the conversion does not depend on that class name.
        if hasattr(indexs, 'numpy'):
            indexs = indexs.numpy()
        return [self.A_paths[index] for index in indexs]