diff --git a/configs/cyclegan_cityscapes.yaml b/configs/cyclegan_cityscapes.yaml
index f74d9e3bdd35e3a521b3056a5b268f67bba2e406..79b941d4bb46f3cbe939096e8b00872bbcf10e6c 100644
--- a/configs/cyclegan_cityscapes.yaml
+++ b/configs/cyclegan_cityscapes.yaml
@@ -36,16 +36,17 @@ dataset:
     output_nc: 3
     serial_batches: False
     pool_size: 50
-    transform:
-      load_size: 286
-      crop_size: 256
-      preprocess: resize_and_crop
-      no_flip: False
-      normalize:
-        mean:
-          (127.5, 127.5, 127.5)
-        std:
-          (127.5, 127.5, 127.5)
+    transforms:
+      - name: Resize
+        size: [286, 286]
+      - name: RandomCrop
+        output_size: [256, 256]
+      - name: RandomHorizontalFlip
+        prob: 0.5
+      - name: Permute
+      - name: Normalize
+        mean: [127.5, 127.5, 127.5]
+        std: [127.5, 127.5, 127.5]
   test:
     name: SingleDataset
     dataroot: data/cityscapes/testB
@@ -55,17 +56,13 @@ dataset:
     output_nc: 3
     serial_batches: False
     pool_size: 50
-    transform:
-      load_size: 256
-      crop_size: 256
-      preprocess: resize_and_crop
-      no_flip: True
-      normalize:
-        mean:
-          (127.5, 127.5, 127.5)
-        std:
-          (127.5, 127.5, 127.5)
-
+    transforms:
+      - name: Resize
+        size: [256, 256]
+      - name: Permute
+      - name: Normalize
+        mean: [127.5, 127.5, 127.5]
+        std: [127.5, 127.5, 127.5]
 
 optimizer:
   name: Adam
diff --git a/configs/cyclegan_horse2zebra.yaml b/configs/cyclegan_horse2zebra.yaml
index 0e845bd5183428f7c166bae300f74757406c07f5..01bb31a7753afafe6d0d47a81d437fe455704625 100644
--- a/configs/cyclegan_horse2zebra.yaml
+++ b/configs/cyclegan_horse2zebra.yaml
@@ -35,16 +35,17 @@ dataset:
     output_nc: 3
     serial_batches: False
     pool_size: 50
-    transform:
-      load_size: 286
-      crop_size: 256
-      preprocess: resize_and_crop
-      no_flip: False
-      normalize:
-        mean:
-          (127.5, 127.5, 127.5)
-        std:
-          (127.5, 127.5, 127.5)
+    transforms:
+      - name: Resize
+        size: [286, 286]
+      - name: RandomCrop
+        output_size: [256, 256]
+      - name: RandomHorizontalFlip
+        prob: 0.5
+      - name: Permute
+      - name: Normalize
+        mean: [127.5, 127.5, 127.5]
+        std: [127.5, 127.5, 127.5]
   test:
     name: SingleDataset
     dataroot: data/horse2zebra/testA
@@ -55,15 +56,13 @@ dataset:
     serial_batches: False
     pool_size: 50
-    transform:
-      load_size: 256
-      crop_size: 256
-      preprocess: resize_and_crop
-      no_flip: True
-      normalize:
-        mean:
-          (127.5, 127.5, 127.5)
-        std:
-          (127.5, 127.5, 127.5)
+    transforms:
+      - name: Resize
+        size: [256, 256]
+      - name: Permute
+      - name: Normalize
+        mean: [127.5, 127.5, 127.5]
+        std: [127.5, 127.5, 127.5]
 
 optimizer:
   name: Adam
diff --git a/configs/pix2pix_cityscapes.yaml b/configs/pix2pix_cityscapes.yaml
index 5919ff2e5a5c2c267a9204d117dc7aba5fb245a7..fabb4b404759f1e0b5cc573d69f5ad5a112748e4 100644
--- a/configs/pix2pix_cityscapes.yaml
+++ b/configs/pix2pix_cityscapes.yaml
@@ -33,16 +33,22 @@ dataset:
     output_nc: 3
     serial_batches: False
     pool_size: 0
-    transform:
-      load_size: 286
-      crop_size: 256
-      preprocess: resize_and_crop
-      no_flip: False
-      normalize:
-        mean:
-          (127.5, 127.5, 127.5)
-        std:
-          (127.5, 127.5, 127.5)
+    transforms:
+      - name: Resize
+        size: [286, 286]
+        keys: [image, image]
+      - name: PairedRandomCrop
+        output_size: [256, 256]
+        keys: [image, image]
+      - name: PairedRandomHorizontalFlip
+        prob: 0.5
+        keys: [image, image]
+      - name: Permute
+        keys: [image, image]
+      - name: Normalize
+        mean: [127.5, 127.5, 127.5]
+        std: [127.5, 127.5, 127.5]
+        keys: [image, image]
   test:
     name: PairedDataset
     dataroot: data/cityscapes/
@@ -53,16 +59,17 @@ dataset:
     output_nc: 3
     serial_batches: True
     pool_size: 50
-    transform:
-      load_size: 256
-      crop_size: 256
-      preprocess: resize_and_crop
-      no_flip: True
-      normalize:
-        mean:
-          (127.5, 127.5, 127.5)
-        std:
-          (127.5, 127.5, 127.5)
+    transforms:
+      - name: Resize
+        size: [256, 256]
+        keys: [image, image]
+      - name: Permute
+        keys: [image, image]
+      - name: Normalize
+        mean: [127.5, 127.5, 127.5]
+        std: [127.5, 127.5, 127.5]
+        keys: [image, image]
+
 optimizer:
   name: Adam
diff --git a/configs/pix2pix_cityscapes_2gpus.yaml b/configs/pix2pix_cityscapes_2gpus.yaml
index 20f494c6fb13690254dd2d047df8c8970615ebff..246599f12d1d921160da3a717cc369b0709027b5 100644
--- a/configs/pix2pix_cityscapes_2gpus.yaml
+++ b/configs/pix2pix_cityscapes_2gpus.yaml
@@ -32,16 +32,22 @@ dataset:
     output_nc: 3
     serial_batches: False
     pool_size: 0
-    transform:
-      load_size: 286
-      crop_size: 256
-      preprocess: resize_and_crop
-      no_flip: False
-      normalize:
-        mean:
-          (127.5, 127.5, 127.5)
-        std:
-          (127.5, 127.5, 127.5)
+    transforms:
+      - name: Resize
+        size: [286, 286]
+        keys: [image, image]
+      - name: PairedRandomCrop
+        output_size: [256, 256]
+        keys: [image, image]
+      - name: PairedRandomHorizontalFlip
+        prob: 0.5
+        keys: [image, image]
+      - name: Permute
+        keys: [image, image]
+      - name: Normalize
+        mean: [127.5, 127.5, 127.5]
+        std: [127.5, 127.5, 127.5]
+        keys: [image, image]
   test:
     name: PairedDataset
     dataroot: data/cityscapes/
@@ -52,16 +58,16 @@ dataset:
     output_nc: 3
     serial_batches: True
     pool_size: 50
-    transform:
-      load_size: 256
-      crop_size: 256
-      preprocess: resize_and_crop
-      no_flip: True
-      normalize:
-        mean:
-          (127.5, 127.5, 127.5)
-        std:
-          (127.5, 127.5, 127.5)
+    transforms:
+      - name: Resize
+        size: [256, 256]
+        keys: [image, image]
+      - name: Permute
+        keys: [image, image]
+      - name: Normalize
+        mean: [127.5, 127.5, 127.5]
+        std: [127.5, 127.5, 127.5]
+        keys: [image, image]
 
 optimizer:
   name: Adam
diff --git a/configs/pix2pix_facades.yaml b/configs/pix2pix_facades.yaml
index 31b5f145dccdfd75bbdcd14c3fa896676d729037..f8c83ff84d3f32ea0a79ab536aab64f5c41a9b8e 100644
--- a/configs/pix2pix_facades.yaml
+++ b/configs/pix2pix_facades.yaml
@@ -32,16 +32,22 @@ dataset:
     output_nc: 3
     serial_batches: False
    pool_size: 0
-    transform:
-      load_size: 286
-      crop_size: 256
-      preprocess: resize_and_crop
-      no_flip: False
-      normalize:
-        mean:
-          (127.5, 127.5, 127.5)
-        std:
-          (127.5, 127.5, 127.5)
+    transforms:
+      - name: Resize
+        size: [286, 286]
+        keys: [image, image]
+      - name: PairedRandomCrop
+        output_size: [256, 256]
+        keys: [image, image]
+      - name: PairedRandomHorizontalFlip
+        prob: 0.5
+        keys: [image, image]
+      - name: Permute
+        keys: [image, image]
+      - name: Normalize
+        mean: [127.5, 127.5, 127.5]
+        std: [127.5, 127.5, 127.5]
+        keys: [image, image]
   test:
     name: PairedDataset
     dataroot: data/facades/
@@ -52,16 +58,16 @@ dataset:
     output_nc: 3
     serial_batches: True
     pool_size: 50
-    transform:
-      load_size: 256
-      crop_size: 256
-      preprocess: resize_and_crop
-      no_flip: True
-      normalize:
-        mean:
-          (127.5, 127.5, 127.5)
-        std:
-          (127.5, 127.5, 127.5)
+    transforms:
+      - name: Resize
+        size: [256, 256]
+        keys: [image, image]
+      - name: Permute
+        keys: [image, image]
+      - name: Normalize
+        mean: [127.5, 127.5, 127.5]
+        std: [127.5, 127.5, 127.5]
+        keys: [image, image]
 
 optimizer:
   name: Adam
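Note: once parsed, each `transforms:` block above is just a list of dicts that is handed to `build_transforms` (introduced later in this patch). A minimal sketch of how a train pipeline is built and applied; the dummy image and values are illustrative only, and it assumes the cv2-backed `paddle.vision` functional ops accept numpy HWC arrays, as the transform code below does:

    import numpy as np
    from ppgan.datasets.transforms.builder import build_transforms

    # parsed equivalent of the train transforms block in cyclegan_horse2zebra.yaml
    cfg = [
        {'name': 'Resize', 'size': [286, 286]},
        {'name': 'RandomCrop', 'output_size': [256, 256]},
        {'name': 'RandomHorizontalFlip', 'prob': 0.5},
        {'name': 'Permute'},
        {'name': 'Normalize', 'mean': [127.5, 127.5, 127.5],
         'std': [127.5, 127.5, 127.5]},
    ]

    pipeline = build_transforms(cfg)
    img = (np.random.rand(300, 300, 3) * 255).astype('float32')  # dummy HWC image
    out = pipeline(img)
    print(out.shape)  # (3, 256, 256); mean/std of 127.5 map [0, 255] to [-1, 1]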
diff --git a/ppgan/datasets/paired_dataset.py b/ppgan/datasets/paired_dataset.py
index 368f8371178ab771d3139103992a97abc3ee0fe8..d15d43b131b96f4892dd6be9ec22b7faebd4607c 100644
--- a/ppgan/datasets/paired_dataset.py
+++ b/ppgan/datasets/paired_dataset.py
@@ -5,13 +5,13 @@
 from .base_dataset import BaseDataset, get_params, get_transform
 from .image_folder import make_dataset
 
 from .builder import DATASETS
+from .transforms.builder import build_transforms
 
 
 @DATASETS.register()
 class PairedDataset(BaseDataset):
     """A dataset class for paired image dataset.
     """
-
     def __init__(self, cfg):
         """Initialize this dataset class.
@@ -19,11 +19,14 @@ class PairedDataset(BaseDataset):
             cfg (dict) -- stores all the experiment flags
         """
         BaseDataset.__init__(self, cfg)
-        self.dir_AB = os.path.join(cfg.dataroot, cfg.phase)  # get the image directory
-        self.AB_paths = sorted(make_dataset(self.dir_AB, cfg.max_dataset_size))  # get image paths
-        assert(self.cfg.transform.load_size >= self.cfg.transform.crop_size)  # crop_size should be smaller than the size of loaded image
+        self.dir_AB = os.path.join(cfg.dataroot,
+                                   cfg.phase)  # get the image directory
+        self.AB_paths = sorted(make_dataset(
+            self.dir_AB, cfg.max_dataset_size))  # get image paths
         self.input_nc = self.cfg.output_nc if self.cfg.direction == 'BtoA' else self.cfg.input_nc
         self.output_nc = self.cfg.input_nc if self.cfg.direction == 'BtoA' else self.cfg.output_nc
+        # the old load_size >= crop_size assert is gone: sizes now live in
+        # the individual transform configs
+        self.transforms = build_transforms(cfg.transforms)
 
     def __getitem__(self, index):
         """Return a data point and its metadata information.
@@ -49,27 +52,20 @@ class PairedDataset(BaseDataset):
         A = AB[:h, :w2, :]
         B = AB[:h, w2:, :]
 
-        # apply the same transform to both A and B
-        # transform_params = get_params(self.opt, A.size)
-        transform_params = get_params(self.cfg.transform, (w2, h))
-
-        A_transform = get_transform(self.cfg.transform, transform_params, grayscale=(self.input_nc == 1))
-        B_transform = get_transform(self.cfg.transform, transform_params, grayscale=(self.output_nc == 1))
-
-        A = A_transform(A)
-        B = B_transform(B)
+        # apply the same (paired) transform pipeline to both A and B
+        A, B = self.transforms((A, B))
 
         return {'A': A, 'B': B, 'A_paths': AB_path, 'B_paths': AB_path}
 
     def __len__(self):
         """Return the total number of images in the dataset."""
         return len(self.AB_paths)
-
-    def get_path_by_indexs(self, indexs):
-        if isinstance(indexs, paddle.Variable):
-            indexs = indexs.numpy()
-        current_paths = []
-        for index in indexs:
-            current_paths.append(self.AB_paths[index])
-        return current_paths
diff --git a/ppgan/datasets/transforms/__init__.py b/ppgan/datasets/transforms/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7734481de315882ca29235ddb4b63aa8d4e7a58c
--- /dev/null
+++ b/ppgan/datasets/transforms/__init__.py
@@ -0,0 +1 @@
+from .transforms import RandomCrop, Resize, RandomHorizontalFlip, PairedRandomCrop, PairedRandomHorizontalFlip, Normalize, Permute
diff --git a/ppgan/datasets/transforms/builder.py b/ppgan/datasets/transforms/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..01e8f9c5b1da88807f10b4aeff76a17c5653e408
--- /dev/null
+++ b/ppgan/datasets/transforms/builder.py
@@ -0,0 +1,43 @@
+import copy
+import traceback
+
+from ...utils.registry import Registry
+
+TRANSFORMS = Registry("TRANSFORMS")
+
+
+class Compose(object):
+    """Compose several transforms together into a single callable
+    that applies each transform to a sample in sequence.
+
+    Args:
+        transforms (list): List of transforms to compose.
+
+    Returns:
+        A callable Compose object; calling it applies each of the
+        given :attr:`transforms` in order.
+    """
+    def __init__(self, transforms):
+        self.transforms = transforms
+
+    def __call__(self, data):
+        for f in self.transforms:
+            try:
+                data = f(data)
+            except Exception as e:
+                stack_info = traceback.format_exc()
+                print("fail to perform transform [{}] with error: "
+                      "{} and stack:\n{}".format(f, e, str(stack_info)))
+                raise e
+        return data
+
+
+def build_transforms(cfg):
+    transforms = []
+
+    for trans_cfg in cfg:
+        temp_trans_cfg = copy.deepcopy(trans_cfg)
+        name = temp_trans_cfg.pop('name')
+        transforms.append(TRANSFORMS.get(name)(**temp_trans_cfg))
+
+    return Compose(transforms)
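Because `TRANSFORMS` is an open registry and `build_transforms` only pops `name` before forwarding the remaining config keys as constructor arguments, new ops can be plugged in without touching the builder. A small sketch with a hypothetical `Clip` transform (not part of this patch):

    import numpy as np
    from ppgan.datasets.transforms.builder import TRANSFORMS, build_transforms

    @TRANSFORMS.register()
    class Clip:
        # hypothetical transform: clamp values to [min_val, max_val]
        def __init__(self, min_val=0.0, max_val=255.0, keys=None):
            self.min_val, self.max_val, self.keys = min_val, max_val, keys

        def __call__(self, inputs):
            return inputs.clip(self.min_val, self.max_val)

    pipeline = build_transforms([{'name': 'Clip', 'max_val': 1.0}])
    print(pipeline(np.array([-2.0, 0.5, 3.0])))  # [0.  0.5 1. ]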
diff --git a/ppgan/datasets/transforms/transforms.py b/ppgan/datasets/transforms/transforms.py
index fa54da01bb23b50f71ea47b795cb67d595ce76ef..97287aab59d12fb1cc42016e7427b3d21435e39f 100644
--- a/ppgan/datasets/transforms/transforms.py
+++ b/ppgan/datasets/transforms/transforms.py
@@ -1,9 +1,101 @@
+import sys
 import random
+import numbers
+import collections
+
+import numpy as np
+import paddle.vision.transforms.functional as F
 
-class RandomCrop(object):
+from .builder import TRANSFORMS
 
-    def __init__(self, output_size):
+if sys.version_info < (3, 3):
+    Sequence = collections.Sequence
+    Iterable = collections.Iterable
+else:
+    Sequence = collections.abc.Sequence
+    Iterable = collections.abc.Iterable
+
+
+class Transform():
+    def _set_attributes(self, args):
+        """Set attributes from the input list of parameters.
+
+        Args:
+            args (list): list of parameters.
+        """
+        if args:
+            for k, v in args.items():
+                if k != "self" and not k.startswith("_"):
+                    setattr(self, k, v)
+
+    def apply_image(self, input):
+        raise NotImplementedError
+
+    def __call__(self, inputs):
+        # apply the transform to each configured key, or directly to the
+        # bare image when no keys are given
+        if isinstance(inputs, tuple):
+            inputs = list(inputs)
+        if self.keys is not None:
+            for i, key in enumerate(self.keys):
+                if isinstance(inputs, dict):
+                    inputs[key] = getattr(self, 'apply_' + key)(inputs[key])
+                elif isinstance(inputs, (list, tuple)):
+                    inputs[i] = getattr(self, 'apply_' + key)(inputs[i])
+        else:
+            inputs = self.apply_image(inputs)
+
+        if isinstance(inputs, list):
+            inputs = tuple(inputs)
+
+        return inputs
+
+
+@TRANSFORMS.register()
+class Resize(Transform):
+    """Resize the input image to the given size.
+
+    Args:
+        size (int|list|tuple): Desired output size. If size is a sequence
+            like (h, w), the output size will be matched to it. If size is
+            an int, the smaller edge of the image will be matched to this
+            number, i.e. if height > width, the image will be rescaled to
+            (size * height / width, size).
+        interpolation (int, optional): Interpolation mode of resize.
+            Default: 1.
+            0: cv2.INTER_NEAREST
+            1: cv2.INTER_LINEAR
+            2: cv2.INTER_CUBIC
+            3: cv2.INTER_AREA
+            4: cv2.INTER_LANCZOS4
+            5: cv2.INTER_LINEAR_EXACT
+    """
+    def __init__(self, size, interpolation=1, keys=None):
+        super().__init__()
+        assert isinstance(size, int) or (isinstance(size, Iterable)
+                                         and len(size) == 2)
+        self._set_attributes(locals())
+        if isinstance(self.size, Iterable):
+            self.size = tuple(size)
+
+    def apply_image(self, img):
+        return F.resize(img, self.size, self.interpolation)
+
+
+@TRANSFORMS.register()
+class RandomCrop(Transform):
+    def __init__(self, output_size, keys=None):
+        super().__init__()
+        self._set_attributes(locals())
         if isinstance(output_size, int):
             self.output_size = (output_size, output_size)
         else:
@@ -19,12 +111,171 @@ class RandomCrop(object):
             j = random.randint(0, w - tw)
         return i, j, th, tw
 
-    def __call__(self, img):
+    def apply_image(self, img):
         i, j, h, w = self._get_params(img)
         cropped_img = img[i:i + h, j:j + w]
         return cropped_img
 
 
+@TRANSFORMS.register()
+class PairedRandomCrop(RandomCrop):
+    def __init__(self, output_size, keys=None):
+        super().__init__(output_size, keys)
+
+    def apply_image(self, img, crop_params=None):
+        if crop_params is not None:
+            i, j, h, w = crop_params
+        else:
+            i, j, h, w = self._get_params(img)
+        return img[i:i + h, j:j + w]
+
+    def __call__(self, inputs):
+        if isinstance(inputs, tuple):
+            inputs = list(inputs)
+        if self.keys is not None:
+            # draw the crop window once, then reuse it for every key
+            if isinstance(inputs, dict):
+                crop_params = self._get_params(inputs[self.keys[0]])
+            elif isinstance(inputs, (list, tuple)):
+                crop_params = self._get_params(inputs[0])
+
+            for i, key in enumerate(self.keys):
+                if isinstance(inputs, dict):
+                    inputs[key] = getattr(self, 'apply_' + key)(inputs[key],
+                                                                crop_params)
+                elif isinstance(inputs, (list, tuple)):
+                    inputs[i] = getattr(self, 'apply_' + key)(inputs[i],
+                                                              crop_params)
+        else:
+            crop_params = self._get_params(inputs)
+            inputs = self.apply_image(inputs, crop_params)
+
+        if isinstance(inputs, list):
+            inputs = tuple(inputs)
+        return inputs
+
+
+@TRANSFORMS.register()
+class RandomHorizontalFlip(Transform):
+    """Horizontally flip the input data randomly with a given probability.
+
+    Args:
+        prob (float): Probability of the input data being flipped.
+            Default: 0.5.
+    """
+    def __init__(self, prob=0.5, keys=None):
+        super().__init__()
+        self._set_attributes(locals())
+
+    def apply_image(self, img):
+        if np.random.random() < self.prob:
+            return F.flip(img, code=1)
+        return img
+
+
+@TRANSFORMS.register()
+class PairedRandomHorizontalFlip(RandomHorizontalFlip):
+    def __init__(self, prob=0.5, keys=None):
+        super().__init__(prob, keys)
+
+    def apply_image(self, img, flip):
+        if flip:
+            return F.flip(img, code=1)
+        return img
+
+    def __call__(self, inputs):
+        if isinstance(inputs, tuple):
+            inputs = list(inputs)
+        # draw the flip decision once, then reuse it for every key
+        flip = np.random.random() < self.prob
+        if self.keys is not None:
+            for i, key in enumerate(self.keys):
+                if isinstance(inputs, dict):
+                    inputs[key] = getattr(self, 'apply_' + key)(inputs[key],
+                                                                flip)
+                elif isinstance(inputs, (list, tuple)):
+                    inputs[i] = getattr(self, 'apply_' + key)(inputs[i], flip)
+        else:
+            inputs = self.apply_image(inputs, flip)
+
+        if isinstance(inputs, list):
+            inputs = tuple(inputs)
+
+        return inputs
+
+
+@TRANSFORMS.register()
+class Normalize(Transform):
+    """Normalize the input data with mean and standard deviation.
+    Given mean: ``(M1,...,Mn)`` and std: ``(S1,...,Sn)`` for ``n`` channels,
+    this transform normalizes each channel of the (CHW) input data:
+    ``output[channel] = (input[channel] - mean[channel]) / std[channel]``
+
+    Args:
+        mean (int|float|list): Sequence of means for each channel.
+        std (int|float|list): Sequence of standard deviations for each channel.
+    """
+    def __init__(self, mean=0.0, std=1.0, keys=None):
+        super().__init__()
+        self._set_attributes(locals())
+
+        if isinstance(mean, numbers.Number):
+            mean = [mean, mean, mean]
+
+        if isinstance(std, numbers.Number):
+            std = [std, std, std]
+
+        # reshape to (C, 1, 1) so they broadcast over CHW images
+        self.mean = np.array(mean, dtype=np.float32).reshape(len(mean), 1, 1)
+        self.std = np.array(std, dtype=np.float32).reshape(len(std), 1, 1)
+
+    def apply_image(self, img):
+        return (img - self.mean) / self.std
+
+
+@TRANSFORMS.register()
+class Permute(Transform):
+    """Change the input data to a target layout.
+    For example, most transforms use HWC images, while neural networks
+    usually take CHW tensors. The input image should be an HWC
+    numpy.ndarray.
+
+    Args:
+        mode (str): Output layout of the input. Default: "CHW".
+        to_rgb (bool): Convert a 'bgr' image to 'rgb'. Default: True.
+    """
+    def __init__(self, mode="CHW", to_rgb=True, keys=None):
+        super().__init__()
+        self._set_attributes(locals())
+        assert mode in [
+            "CHW"
+        ], "Only support 'CHW' mode, but received mode: {}".format(mode)
+        self.mode = mode
+        self.to_rgb = to_rgb
+
+    def apply_image(self, img):
+        if self.to_rgb:
+            img = img[..., ::-1]  # BGR -> RGB
+        if self.mode == "CHW":
+            return img.transpose((2, 0, 1))
+        return img
+
+
 class Crop():
     def __init__(self, pos, size):
         self.pos = pos
@@ -35,6 +286,6 @@ class Crop():
         x, y = self.pos
         th = tw = self.size
         if (ow > tw or oh > th):
-            return img[y: y + th, x: x + tw]
-        return img
\ No newline at end of file
+            return img[y:y + th, x:x + tw]
+        return img
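The point of the `Paired*` subclasses above is that the random parameters (crop window, flip decision) are drawn once per sample and reused for every entry in `keys`, so the two halves of a pix2pix pair stay pixel-aligned. A quick sketch with dummy arrays:

    import numpy as np
    from ppgan.datasets.transforms.transforms import PairedRandomCrop

    crop = PairedRandomCrop(output_size=(256, 256), keys=['image', 'image'])
    A = np.zeros((286, 286, 3), dtype='float32')
    B = np.ones((286, 286, 3), dtype='float32')
    A_c, B_c = crop((A, B))  # the same 256x256 window is cut from both
    assert A_c.shape == B_c.shape == (256, 256, 3)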
diff --git a/ppgan/datasets/unpaired_dataset.py b/ppgan/datasets/unpaired_dataset.py
index 5cabc5391b84e9f6aa55e0925d4202c7b3d09418..45a7c4f47fccdbb2437775963fbf25f3a1a81134 100644
--- a/ppgan/datasets/unpaired_dataset.py
+++ b/ppgan/datasets/unpaired_dataset.py
@@ -5,13 +5,13 @@
 from .base_dataset import BaseDataset, get_transform
 from .image_folder import make_dataset
 
 from .builder import DATASETS
+from .transforms.builder import build_transforms
 
 
 @DATASETS.register()
 class UnpairedDataset(BaseDataset):
     """
     """
-
     def __init__(self, cfg):
         """Initialize this dataset class.
@@ -19,18 +19,26 @@ class UnpairedDataset(BaseDataset):
             cfg (dict) -- stores all the experiment flags
         """
         BaseDataset.__init__(self, cfg)
-        self.dir_A = os.path.join(cfg.dataroot, cfg.phase + 'A')  # create a path '/path/to/data/trainA'
-        self.dir_B = os.path.join(cfg.dataroot, cfg.phase + 'B')  # create a path '/path/to/data/trainB'
+        self.dir_A = os.path.join(cfg.dataroot, cfg.phase +
+                                  'A')  # create a path '/path/to/data/trainA'
+        self.dir_B = os.path.join(cfg.dataroot, cfg.phase +
+                                  'B')  # create a path '/path/to/data/trainB'
 
-        self.A_paths = sorted(make_dataset(self.dir_A, cfg.max_dataset_size))  # load images from '/path/to/data/trainA'
-        self.B_paths = sorted(make_dataset(self.dir_B, cfg.max_dataset_size))  # load images from '/path/to/data/trainB'
+        self.A_paths = sorted(make_dataset(
+            self.dir_A,
+            cfg.max_dataset_size))  # load images from '/path/to/data/trainA'
+        self.B_paths = sorted(make_dataset(
+            self.dir_B,
+            cfg.max_dataset_size))  # load images from '/path/to/data/trainB'
         self.A_size = len(self.A_paths)  # get the size of dataset A
         self.B_size = len(self.B_paths)  # get the size of dataset B
         btoA = self.cfg.direction == 'BtoA'
-        input_nc = self.cfg.output_nc if btoA else self.cfg.input_nc    # get the number of channels of input image
-        output_nc = self.cfg.input_nc if btoA else self.cfg.output_nc   # get the number of channels of output image
-        self.transform_A = get_transform(self.cfg.transform, grayscale=(input_nc == 1))
-        self.transform_B = get_transform(self.cfg.transform, grayscale=(output_nc == 1))
+        input_nc = self.cfg.output_nc if btoA else self.cfg.input_nc  # get the number of channels of input image
+        output_nc = self.cfg.input_nc if btoA else self.cfg.output_nc  # get the number of channels of output image
+        # build the transform pipelines declared in the dataset config
+        self.transform_A = build_transforms(self.cfg.transforms)
+        self.transform_B = build_transforms(self.cfg.transforms)
 
         self.reset_paths()
@@ -49,10 +57,11 @@ class UnpairedDataset(BaseDataset):
             A_paths (str) -- image paths
             B_paths (str) -- image paths
         """
-        A_path = self.A_paths[index % self.A_size]  # make sure index is within then range
-        if self.cfg.serial_batches:   # make sure index is within then range
+        A_path = self.A_paths[
+            index % self.A_size]  # make sure index is within the range
+        if self.cfg.serial_batches:  # keep A and B in a fixed order
             index_B = index % self.B_size
-        else:   # randomize the index for domain B to avoid fixed pairs.
+        else:  # randomize the index for domain B to avoid fixed pairs.
             index_B = random.randint(0, self.B_size - 1)
         B_path = self.B_paths[index_B]
diff --git a/ppgan/models/builder.py b/ppgan/models/builder.py
index bd2ed58f096679a00549d53b82e9e603e7208433..607f4e915f43eb85b0032981e70081ba03cb2a8c 100644
--- a/ppgan/models/builder.py
+++ b/ppgan/models/builder.py
@@ -2,18 +2,9 @@
 import paddle
 
 from ..utils.registry import Registry
 
-
 MODELS = Registry("MODEL")
 
 
 def build_model(cfg):
-    # dataset = MODELS.get(cfg.MODEL.name)(cfg.MODEL)
-    # place = paddle.CUDAPlace(0)
-    # dataloader = paddle.io.DataLoader(dataset,
-    #                                   batch_size=1, #opt.batch_size,
-    #                                   places=place,
-    #                                   shuffle=True, #not opt.serial_batches,
-    #                                   num_workers=0)#int(opt.num_threads))
     model = MODELS.get(cfg.model.name)(cfg)
     return model
-    # pass
\ No newline at end of file
diff --git a/ppgan/models/pix2pix_model.py b/ppgan/models/pix2pix_model.py
index 3786b444ba7b73d98eb31d8b5bfca117e8567dd1..e9ff9a28d45fc8eb8916eaf254a8a35012440ab4 100644
--- a/ppgan/models/pix2pix_model.py
+++ b/ppgan/models/pix2pix_model.py
@@ -77,8 +77,8 @@ class Pix2PixModel(BaseModel):
         """
         AtoB = self.opt.dataset.train.direction == 'AtoB'
-        self.real_A = paddle.to_tensor(input['A' if AtoB else 'B'])
-        self.real_B = paddle.to_tensor(input['B' if AtoB else 'A'])
+        self.real_A = paddle.to_variable(input['A' if AtoB else 'B'])
+        self.real_B = paddle.to_variable(input['B' if AtoB else 'A'])
         self.image_paths = input['A_paths' if AtoB else 'B_paths']
 
     def forward(self):
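A final ordering note that all the configs in this patch respect: `Normalize` reshapes `mean`/`std` to `(C, 1, 1)` and therefore expects CHW input, so `Permute` must run before it. Checking the arithmetic with the values used here (dummy data):

    import numpy as np
    from ppgan.datasets.transforms.transforms import Normalize, Permute

    img = np.full((4, 4, 3), 255, dtype='float32')  # all-white HWC image
    chw = Permute()(img)                            # HWC/BGR -> CHW/RGB
    out = Normalize(mean=[127.5] * 3, std=[127.5] * 3)(chw)
    print(out.min(), out.max())  # 1.0 1.0; a 0-valued pixel maps to -1.0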