diff --git a/applications/first_order_model/configs/vox-256.yaml b/applications/first_order_model/configs/vox-256.yaml
index abfe9a23949aea62f9b4b7772d3e987f253ecd62..1f38c49af91374323beeed7b027d1fb324e371fb 100644
--- a/applications/first_order_model/configs/vox-256.yaml
+++ b/applications/first_order_model/configs/vox-256.yaml
@@ -1,19 +1,3 @@
-dataset_params:
-  root_dir: data/vox-png
-  frame_shape: [256, 256, 3]
-  id_sampling: True
-  pairs_list: data/vox256.csv
-  augmentation_params:
-    flip_param:
-      horizontal_flip: True
-      time_flip: True
-    jitter_param:
-      brightness: 0.1
-      contrast: 0.1
-      saturation: 0.1
-      hue: 0.1
-
-
 model_params:
   common_params:
     num_kp: 10
@@ -42,42 +26,3 @@ model_params:
     max_features: 512
     num_blocks: 4
     sn: True
-
-train_params:
-  num_epochs: 100
-  num_repeats: 75
-  epoch_milestones: [60, 90]
-  lr_generator: 2.0e-4
-  lr_discriminator: 2.0e-4
-  lr_kp_detector: 2.0e-4
-  batch_size: 40
-  scales: [1, 0.5, 0.25, 0.125]
-  checkpoint_freq: 50
-  transform_params:
-    sigma_affine: 0.05
-    sigma_tps: 0.005
-    points_tps: 5
-  loss_weights:
-    generator_gan: 0
-    discriminator_gan: 1
-    feature_matching: [10, 10, 10, 10]
-    perceptual: [10, 10, 10, 10, 10]
-    equivariance_value: 10
-    equivariance_jacobian: 10
-
-reconstruction_params:
-  num_videos: 1000
-  format: '.mp4'
-
-animate_params:
-  num_pairs: 50
-  format: '.mp4'
-  normalization_params:
-    adapt_movement_scale: False
-    use_relative_movement: True
-    use_relative_jacobian: True
-
-visualizer_params:
-  kp_size: 5
-  draw_border: True
-  colormap: 'gist_rainbow'
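Review note: with the dataset/train/visualizer sections gone, this config now carries only `model_params`. A minimal standalone sketch (not part of the patch, assuming PyYAML and the repo layout above) of loading what remains:

```python
import yaml

# Path matches the file touched in this diff; adjust to your checkout.
with open("applications/first_order_model/configs/vox-256.yaml") as f:
    cfg = yaml.safe_load(f)

# Only model_params survives the trim.
print(cfg["model_params"]["common_params"]["num_kp"])  # 10
```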
diff --git a/applications/tools/first-order-demo.py b/applications/tools/first-order-demo.py
index 40e6c1a35549aaeaf1181d23c47c3f968fe94f21..5605bf54078b992499afb341cfd5173bb778db74 100644
--- a/applications/tools/first-order-demo.py
+++ b/applications/tools/first-order-demo.py
@@ -1,3 +1,17 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import matplotlib
 matplotlib.use('Agg')
 import os
@@ -5,20 +19,20 @@
 import sys
 import yaml
 import pickle
-from argparse import ArgumentParser
-from tqdm import tqdm
-
 import imageio
 import numpy as np
-from skimage.transform import resize
+
+from tqdm import tqdm
 from skimage import img_as_ubyte
-import paddle
+from argparse import ArgumentParser
+from skimage.transform import resize
+from scipy.spatial import ConvexHull
 
 from ppgan.models.generators.occlusion_aware import OcclusionAwareGenerator
 from ppgan.modules.keypoint_detector import KPDetector
 from ppgan.utils.animate import normalize_kp
-from scipy.spatial import ConvexHull
+import paddle
 
 paddle.disable_static()
 if sys.version_info[0] < 3:
@@ -60,8 +74,7 @@ def make_animation(source_image,
     predictions = []
     source = paddle.to_tensor(source_image[np.newaxis].astype(
         np.float32)).transpose([0, 3, 1, 2])
-    # if not cpu:
-    #     source = source.cuda()
+
     driving = paddle.to_tensor(
         np.array(driving_video)[np.newaxis].astype(np.float32)).transpose(
             [0, 4, 1, 2, 3])
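Review note: the tensor preparation kept in the last hunk converts an HWC image to Paddle's NCHW batch layout. A standalone sketch of that conversion (shapes illustrative, not part of the patch):

```python
import numpy as np
import paddle

# HWC float image in [0, 1], as make_animation receives it.
source_image = np.random.rand(256, 256, 3).astype(np.float32)

# np.newaxis adds the batch dim (NHWC), transpose reorders to NCHW.
source = paddle.to_tensor(source_image[np.newaxis]).transpose([0, 3, 1, 2])
print(source.shape)  # [1, 3, 256, 256]
```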
diff --git a/applications/tools/video-enhance.py b/applications/tools/video-enhance.py
index 04ece7689d33f37c111e5f5acf2c20969f83c2bd..5fb8dbecce128a0cc5447a3e708bf3bfce57768e 100644
--- a/applications/tools/video-enhance.py
+++ b/applications/tools/video-enhance.py
@@ -1,3 +1,17 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import sys
 sys.path.append('.')
 
diff --git a/configs/cyclegan_cityscapes.yaml b/configs/cyclegan_cityscapes.yaml
index f74d9e3bdd35e3a521b3056a5b268f67bba2e406..97bf179e56d435178af468a39ede578878e826f1 100644
--- a/configs/cyclegan_cityscapes.yaml
+++ b/configs/cyclegan_cityscapes.yaml
@@ -36,16 +36,18 @@ dataset:
     output_nc: 3
     serial_batches: False
     pool_size: 50
-    transform:
-        load_size: 286
-        crop_size: 256
-        preprocess: resize_and_crop
-        no_flip: False
-        normalize:
-            mean:
-                (127.5, 127.5, 127.5)
-            std:
-                (127.5, 127.5, 127.5)
+    transforms:
+      - name: Resize
+        size: [286, 286]
+        interpolation: 2 #cv2.INTER_CUBIC
+      - name: RandomCrop
+        output_size: [256, 256]
+      - name: RandomHorizontalFlip
+        prob: 0.5
+      - name: Permute
+      - name: Normalize
+        mean: [127.5, 127.5, 127.5]
+        std: [127.5, 127.5, 127.5]
   test:
     name: SingleDataset
     dataroot: data/cityscapes/testB
@@ -55,17 +57,14 @@ dataset:
     output_nc: 3
     serial_batches: False
    pool_size: 50
-    transform:
-        load_size: 256
-        crop_size: 256
-        preprocess: resize_and_crop
-        no_flip: True
-        normalize:
-            mean:
-                (127.5, 127.5, 127.5)
-            std:
-                (127.5, 127.5, 127.5)
-
+    transforms:
+      - name: Resize
+        size: [256, 256]
+        interpolation: 2 #cv2.INTER_CUBIC
+      - name: Permute
+      - name: Normalize
+        mean: [127.5, 127.5, 127.5]
+        std: [127.5, 127.5, 127.5]
 
 optimizer:
   name: Adam
diff --git a/configs/cyclegan_horse2zebra.yaml b/configs/cyclegan_horse2zebra.yaml
index 0e845bd5183428f7c166bae300f74757406c07f5..c86cd8c3440203a53269d71b9c627c0dafe351f8 100644
--- a/configs/cyclegan_horse2zebra.yaml
+++ b/configs/cyclegan_horse2zebra.yaml
@@ -35,16 +35,18 @@ dataset:
     output_nc: 3
     serial_batches: False
     pool_size: 50
-    transform:
-        load_size: 286
-        crop_size: 256
-        preprocess: resize_and_crop
-        no_flip: False
-        normalize:
-            mean:
-                (127.5, 127.5, 127.5)
-            std:
-                (127.5, 127.5, 127.5)
+    transforms:
+      - name: Resize
+        size: [286, 286]
+        interpolation: 2 #cv2.INTER_CUBIC
+      - name: RandomCrop
+        output_size: [256, 256]
+      - name: RandomHorizontalFlip
+        prob: 0.5
+      - name: Permute
+      - name: Normalize
+        mean: [127.5, 127.5, 127.5]
+        std: [127.5, 127.5, 127.5]
   test:
     name: SingleDataset
     dataroot: data/horse2zebra/testA
@@ -55,15 +57,13 @@ dataset:
     output_nc: 3
     serial_batches: False
     pool_size: 50
-    transform:
-        load_size: 256
-        crop_size: 256
-        preprocess: resize_and_crop
-        no_flip: True
-        normalize:
-            mean:
-                (127.5, 127.5, 127.5)
-            std:
-                (127.5, 127.5, 127.5)
+    transforms:
+      - name: Resize
+        size: [256, 256]
+        interpolation: 2 #cv2.INTER_CUBIC
+      - name: Permute
+      - name: Normalize
+        mean: [127.5, 127.5, 127.5]
+        std: [127.5, 127.5, 127.5]
 
 optimizer:
   name: Adam
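Review note: the `interpolation: 2` entries in these configs rely on OpenCV's integer codes, as the inline comments say. A quick standalone check (assuming opencv-python is installed; not part of the patch):

```python
import cv2

# The integer codes used across these configs map to OpenCV constants.
assert cv2.INTER_NEAREST == 0
assert cv2.INTER_LINEAR == 1
assert cv2.INTER_CUBIC == 2  # the value these configs pick for Resize
```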
diff --git a/configs/pix2pix_cityscapes.yaml b/configs/pix2pix_cityscapes.yaml
index 5919ff2e5a5c2c267a9204d117dc7aba5fb245a7..5a6dd3bbce4f49178a12495c294cf1eef2778071 100644
--- a/configs/pix2pix_cityscapes.yaml
+++ b/configs/pix2pix_cityscapes.yaml
@@ -33,16 +33,23 @@ dataset:
     output_nc: 3
     serial_batches: False
     pool_size: 0
-    transform:
-        load_size: 286
-        crop_size: 256
-        preprocess: resize_and_crop
-        no_flip: False
-        normalize:
-            mean:
-                (127.5, 127.5, 127.5)
-            std:
-                (127.5, 127.5, 127.5)
+    transforms:
+      - name: Resize
+        size: [286, 286]
+        interpolation: 2 #cv2.INTER_CUBIC
+        keys: [image, image]
+      - name: PairedRandomCrop
+        output_size: [256, 256]
+        keys: [image, image]
+      - name: PairedRandomHorizontalFlip
+        prob: 0.5
+        keys: [image, image]
+      - name: Permute
+        keys: [image, image]
+      - name: Normalize
+        mean: [127.5, 127.5, 127.5]
+        std: [127.5, 127.5, 127.5]
+        keys: [image, image]
   test:
     name: PairedDataset
     dataroot: data/cityscapes/
@@ -53,16 +60,18 @@ dataset:
     output_nc: 3
     serial_batches: True
     pool_size: 50
-    transform:
-        load_size: 256
-        crop_size: 256
-        preprocess: resize_and_crop
-        no_flip: True
-        normalize:
-            mean:
-                (127.5, 127.5, 127.5)
-            std:
-                (127.5, 127.5, 127.5)
+    transforms:
+      - name: Resize
+        size: [256, 256]
+        interpolation: 2 #cv2.INTER_CUBIC
+        keys: [image, image]
+      - name: Permute
+        keys: [image, image]
+      - name: Normalize
+        mean: [127.5, 127.5, 127.5]
+        std: [127.5, 127.5, 127.5]
+        keys: [image, image]
+
 optimizer:
   name: Adam
diff --git a/configs/pix2pix_cityscapes_2gpus.yaml b/configs/pix2pix_cityscapes_2gpus.yaml
index 20f494c6fb13690254dd2d047df8c8970615ebff..41279f0be0737f0230b6b5539700907437289391 100644
--- a/configs/pix2pix_cityscapes_2gpus.yaml
+++ b/configs/pix2pix_cityscapes_2gpus.yaml
@@ -32,16 +32,23 @@ dataset:
     output_nc: 3
     serial_batches: False
     pool_size: 0
-    transform:
-        load_size: 286
-        crop_size: 256
-        preprocess: resize_and_crop
-        no_flip: False
-        normalize:
-            mean:
-                (127.5, 127.5, 127.5)
-            std:
-                (127.5, 127.5, 127.5)
+    transforms:
+      - name: Resize
+        size: [286, 286]
+        interpolation: 2 #cv2.INTER_CUBIC
+        keys: [image, image]
+      - name: PairedRandomCrop
+        output_size: [256, 256]
+        keys: [image, image]
+      - name: PairedRandomHorizontalFlip
+        prob: 0.5
+        keys: [image, image]
+      - name: Permute
+        keys: [image, image]
+      - name: Normalize
+        mean: [127.5, 127.5, 127.5]
+        std: [127.5, 127.5, 127.5]
+        keys: [image, image]
   test:
     name: PairedDataset
     dataroot: data/cityscapes/
@@ -52,16 +59,17 @@ dataset:
     output_nc: 3
     serial_batches: True
     pool_size: 50
-    transform:
-        load_size: 256
-        crop_size: 256
-        preprocess: resize_and_crop
-        no_flip: True
-        normalize:
-            mean:
-                (127.5, 127.5, 127.5)
-            std:
-                (127.5, 127.5, 127.5)
+    transforms:
+      - name: Resize
+        size: [256, 256]
+        interpolation: 2 #cv2.INTER_CUBIC
+        keys: [image, image]
+      - name: Permute
+        keys: [image, image]
+      - name: Normalize
+        mean: [127.5, 127.5, 127.5]
+        std: [127.5, 127.5, 127.5]
+        keys: [image, image]
 
 optimizer:
   name: Adam
diff --git a/configs/pix2pix_facades.yaml b/configs/pix2pix_facades.yaml
index 31b5f145dccdfd75bbdcd14c3fa896676d729037..37e89ce120185100a6146fa5f550334f522eef42 100644
--- a/configs/pix2pix_facades.yaml
+++ b/configs/pix2pix_facades.yaml
@@ -32,16 +32,23 @@ dataset:
     output_nc: 3
     serial_batches: False
     pool_size: 0
-    transform:
-        load_size: 286
-        crop_size: 256
-        preprocess: resize_and_crop
-        no_flip: False
-        normalize:
-            mean:
-                (127.5, 127.5, 127.5)
-            std:
-                (127.5, 127.5, 127.5)
+    transforms:
+      - name: Resize
+        size: [286, 286]
+        interpolation: 2 #cv2.INTER_CUBIC
+        keys: [image, image]
+      - name: PairedRandomCrop
+        output_size: [256, 256]
+        keys: [image, image]
+      - name: PairedRandomHorizontalFlip
+        prob: 0.5
+        keys: [image, image]
+      - name: Permute
+        keys: [image, image]
+      - name: Normalize
+        mean: [127.5, 127.5, 127.5]
+        std: [127.5, 127.5, 127.5]
+        keys: [image, image]
   test:
     name: PairedDataset
     dataroot: data/facades/
@@ -52,16 +59,17 @@ dataset:
     output_nc: 3
     serial_batches: True
     pool_size: 50
-    transform:
-        load_size: 256
-        crop_size: 256
-        preprocess: resize_and_crop
-        no_flip: True
-        normalize:
-            mean:
-                (127.5, 127.5, 127.5)
-            std:
-                (127.5, 127.5, 127.5)
+    transforms:
+      - name: Resize
+        size: [256, 256]
+        interpolation: 2 #cv2.INTER_CUBIC
+        keys: [image, image]
+      - name: Permute
+        keys: [image, image]
+      - name: Normalize
+        mean: [127.5, 127.5, 127.5]
+        std: [127.5, 127.5, 127.5]
+        keys: [image, image]
 
 optimizer:
   name: Adam
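Review note: in these pix2pix configs the `transforms` section deserializes to a plain list of dicts, which is exactly the shape `build_transforms` (added below in `ppgan/datasets/transforms/builder.py`) consumes; `name` selects the registered class and the remaining keys become constructor kwargs. A standalone check with PyYAML:

```python
import yaml

snippet = """
transforms:
  - name: Resize
    size: [256, 256]
    interpolation: 2
    keys: [image, image]
  - name: Permute
    keys: [image, image]
"""
cfg = yaml.safe_load(snippet)
print(cfg["transforms"][0])
# {'name': 'Resize', 'size': [256, 256], 'interpolation': 2, 'keys': ['image', 'image']}
```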
diff --git a/ppgan/datasets/paired_dataset.py b/ppgan/datasets/paired_dataset.py
index 368f8371178ab771d3139103992a97abc3ee0fe8..4a68bfab210736c256389ca02c6db804ac608fe4 100644
--- a/ppgan/datasets/paired_dataset.py
+++ b/ppgan/datasets/paired_dataset.py
@@ -5,13 +5,13 @@
 from .base_dataset import BaseDataset, get_params, get_transform
 from .image_folder import make_dataset
 
 from .builder import DATASETS
+from .transforms.builder import build_transforms
 
 
 @DATASETS.register()
 class PairedDataset(BaseDataset):
     """A dataset class for paired image dataset.
     """
-
     def __init__(self, cfg):
         """Initialize this dataset class.
 
@@ -19,11 +19,14 @@ class PairedDataset(BaseDataset):
             cfg (dict) -- stores all the experiment flags
         """
         BaseDataset.__init__(self, cfg)
-        self.dir_AB = os.path.join(cfg.dataroot, cfg.phase)  # get the image directory
-        self.AB_paths = sorted(make_dataset(self.dir_AB, cfg.max_dataset_size))  # get image paths
-        assert(self.cfg.transform.load_size >= self.cfg.transform.crop_size)  # crop_size should be smaller than the size of loaded image
+        self.dir_AB = os.path.join(cfg.dataroot,
+                                   cfg.phase)  # get the image directory
+        self.AB_paths = sorted(make_dataset(
+            self.dir_AB, cfg.max_dataset_size))  # get image paths
+
         self.input_nc = self.cfg.output_nc if self.cfg.direction == 'BtoA' else self.cfg.input_nc
         self.output_nc = self.cfg.input_nc if self.cfg.direction == 'BtoA' else self.cfg.output_nc
+        self.transforms = build_transforms(cfg.transforms)
 
     def __getitem__(self, index):
         """Return a data point and its metadata information.
@@ -49,27 +52,11 @@ class PairedDataset(BaseDataset):
         A = AB[:h, :w2, :]
         B = AB[:h, w2:, :]
 
-        # apply the same transform to both A and B
-        # transform_params = get_params(self.opt, A.size)
-        transform_params = get_params(self.cfg.transform, (w2, h))
-
-        A_transform = get_transform(self.cfg.transform, transform_params, grayscale=(self.input_nc == 1))
-        B_transform = get_transform(self.cfg.transform, transform_params, grayscale=(self.output_nc == 1))
-
-        A = A_transform(A)
-        B = B_transform(B)
+        A, B = self.transforms((A, B))
 
         return {'A': A, 'B': B, 'A_paths': AB_path, 'B_paths': AB_path}
 
     def __len__(self):
         """Return the total number of images in the dataset."""
         return len(self.AB_paths)
-
-    def get_path_by_indexs(self, indexs):
-        if isinstance(indexs, paddle.Variable):
-            indexs = indexs.numpy()
-        current_paths = []
-        for index in indexs:
-            current_paths.append(self.AB_paths[index])
-        return current_paths
diff --git a/ppgan/datasets/transforms/__init__.py b/ppgan/datasets/transforms/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7734481de315882ca29235ddb4b63aa8d4e7a58c
--- /dev/null
+++ b/ppgan/datasets/transforms/__init__.py
@@ -0,0 +1 @@
+from .transforms import RandomCrop, Resize, RandomHorizontalFlip, PairedRandomCrop, PairedRandomHorizontalFlip, Normalize, Permute
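Review note: `A, B = self.transforms((A, B))` pushes the whole pair through a single pipeline, with `keys` routing each tuple element through `apply_image`. A sketch using the classes this PR adds (dummy arrays in place of real images; `Compose` comes from the builder module added below):

```python
import numpy as np
from ppgan.datasets.transforms import Normalize, Permute
from ppgan.datasets.transforms.builder import Compose

pipeline = Compose([
    Permute(keys=['image', 'image']),
    Normalize(mean=127.5, std=127.5, keys=['image', 'image']),
])
A = np.zeros((256, 256, 3), dtype=np.float32)
B = np.full((256, 256, 3), 255, dtype=np.float32)
A_t, B_t = pipeline((A, B))
print(A_t.shape, float(A_t.max()), float(B_t.max()))  # (3, 256, 256) -1.0 1.0
```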
+ + """ + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, data): + for f in self.transforms: + try: + data = f(data) + except Exception as e: + stack_info = traceback.format_exc() + print("fail to perform transform [{}] with error: " + "{} and stack:\n{}".format(f, e, str(stack_info))) + raise e + return data + + +def build_transforms(cfg): + transforms = [] + + for trans_cfg in cfg: + temp_trans_cfg = copy.deepcopy(trans_cfg) + name = temp_trans_cfg.pop('name') + transforms.append(TRANSFORMS.get(name)(**temp_trans_cfg)) + + transforms = Compose(transforms) + return transforms diff --git a/ppgan/datasets/transforms/transforms.py b/ppgan/datasets/transforms/transforms.py index fa54da01bb23b50f71ea47b795cb67d595ce76ef..f2b21564dc9d651e793028656f98d0360687a7bd 100644 --- a/ppgan/datasets/transforms/transforms.py +++ b/ppgan/datasets/transforms/transforms.py @@ -1,9 +1,95 @@ +import sys import random +import numbers +import collections +import numpy as np +from paddle.utils import try_import +import paddle.vision.transforms.functional as F -class RandomCrop(object): +from .builder import TRANSFORMS - def __init__(self, output_size): +if sys.version_info < (3, 3): + Sequence = collections.Sequence + Iterable = collections.Iterable +else: + Sequence = collections.abc.Sequence + Iterable = collections.abc.Iterable + + +class Transform(): + def _set_attributes(self, args): + """ + Set attributes from the input list of parameters. + + Args: + args (list): list of parameters. + """ + if args: + for k, v in args.items(): + if k != "self" and not k.startswith("_"): + setattr(self, k, v) + + def apply_image(self, input): + raise NotImplementedError + + def __call__(self, inputs): + if isinstance(inputs, tuple): + inputs = list(inputs) + if self.keys is not None: + for i, key in enumerate(self.keys): + if isinstance(inputs, dict): + inputs[key] = getattr(self, 'apply_' + key)(inputs[key]) + elif isinstance(inputs, (list, tuple)): + inputs[i] = getattr(self, 'apply_' + key)(inputs[i]) + else: + inputs = self.apply_image(inputs) + + if isinstance(inputs, list): + inputs = tuple(inputs) + + return inputs + + +@TRANSFORMS.register() +class Resize(Transform): + """Resize the input Image to the given size. + + Args: + size (int|list|tuple): Desired output size. If size is a sequence like + (h, w), output size will be matched to this. If size is an int, + smaller edge of the image will be matched to this number. + i.e, if height > width, then image will be rescaled to + (size * height / width, size) + interpolation (int, optional): Interpolation mode of resize. Default: 1. 
diff --git a/ppgan/datasets/transforms/transforms.py b/ppgan/datasets/transforms/transforms.py
index fa54da01bb23b50f71ea47b795cb67d595ce76ef..f2b21564dc9d651e793028656f98d0360687a7bd 100644
--- a/ppgan/datasets/transforms/transforms.py
+++ b/ppgan/datasets/transforms/transforms.py
@@ -1,9 +1,95 @@
+import sys
 import random
+import numbers
+import collections
+import numpy as np
+from paddle.utils import try_import
+import paddle.vision.transforms.functional as F
 
-class RandomCrop(object):
+from .builder import TRANSFORMS
 
-    def __init__(self, output_size):
+if sys.version_info < (3, 3):
+    Sequence = collections.Sequence
+    Iterable = collections.Iterable
+else:
+    Sequence = collections.abc.Sequence
+    Iterable = collections.abc.Iterable
+
+
+class Transform():
+    def _set_attributes(self, args):
+        """
+        Set attributes from the input list of parameters.
+
+        Args:
+            args (list): list of parameters.
+        """
+        if args:
+            for k, v in args.items():
+                if k != "self" and not k.startswith("_"):
+                    setattr(self, k, v)
+
+    def apply_image(self, input):
+        raise NotImplementedError
+
+    def __call__(self, inputs):
+        if isinstance(inputs, tuple):
+            inputs = list(inputs)
+        if self.keys is not None:
+            for i, key in enumerate(self.keys):
+                if isinstance(inputs, dict):
+                    inputs[key] = getattr(self, 'apply_' + key)(inputs[key])
+                elif isinstance(inputs, (list, tuple)):
+                    inputs[i] = getattr(self, 'apply_' + key)(inputs[i])
+        else:
+            inputs = self.apply_image(inputs)
+
+        if isinstance(inputs, list):
+            inputs = tuple(inputs)
+
+        return inputs
+
+
+@TRANSFORMS.register()
+class Resize(Transform):
+    """Resize the input Image to the given size.
+
+    Args:
+        size (int|list|tuple): Desired output size. If size is a sequence like
+            (h, w), output size will be matched to this. If size is an int,
+            smaller edge of the image will be matched to this number.
+            i.e, if height > width, then image will be rescaled to
+            (size * height / width, size)
+        interpolation (int, optional): Interpolation mode of resize. Default: 1.
+            0 : cv2.INTER_NEAREST
+            1 : cv2.INTER_LINEAR
+            2 : cv2.INTER_CUBIC
+            3 : cv2.INTER_AREA
+            4 : cv2.INTER_LANCZOS4
+            5 : cv2.INTER_LINEAR_EXACT
+            7 : cv2.INTER_MAX
+            8 : cv2.WARP_FILL_OUTLIERS
+            16: cv2.WARP_INVERSE_MAP
+
+    """
+    def __init__(self, size, interpolation=1, keys=None):
+        super().__init__()
+        assert isinstance(size, int) or (isinstance(size, Iterable)
+                                         and len(size) == 2)
+        self._set_attributes(locals())
+        if isinstance(self.size, Iterable):
+            self.size = tuple(size)
+
+    def apply_image(self, img):
+        return F.resize(img, self.size, self.interpolation)
+
+
+@TRANSFORMS.register()
+class RandomCrop(Transform):
+    def __init__(self, output_size, keys=None):
+        super().__init__()
+        self._set_attributes(locals())
         if isinstance(output_size, int):
             self.output_size = (output_size, output_size)
         else:
@@ -19,12 +105,162 @@ class RandomCrop(object):
             j = random.randint(0, w - tw)
         return i, j, th, tw
 
-    def __call__(self, img):
+    def apply_image(self, img):
         i, j, h, w = self._get_params(img)
         cropped_img = img[i:i + h, j:j + w]
         return cropped_img
 
 
+@TRANSFORMS.register()
+class PairedRandomCrop(RandomCrop):
+    def __init__(self, output_size, keys=None):
+        super().__init__(output_size, keys)
+
+        if isinstance(output_size, int):
+            self.output_size = (output_size, output_size)
+        else:
+            self.output_size = output_size
+
+    def apply_image(self, img, crop_params=None):
+        if crop_params is not None:
+            i, j, h, w = crop_params
+        else:
+            i, j, h, w = self._get_params(img)
+        cropped_img = img[i:i + h, j:j + w]
+        return cropped_img
+
+    def __call__(self, inputs):
+        if isinstance(inputs, tuple):
+            inputs = list(inputs)
+        if self.keys is not None:
+            if isinstance(inputs, dict):
+                crop_params = self._get_params(inputs[self.keys[0]])
+            elif isinstance(inputs, (list, tuple)):
+                crop_params = self._get_params(inputs[0])
+
+            for i, key in enumerate(self.keys):
+                if isinstance(inputs, dict):
+                    inputs[key] = getattr(self, 'apply_' + key)(inputs[key],
+                                                                crop_params)
+                elif isinstance(inputs, (list, tuple)):
+                    inputs[i] = getattr(self, 'apply_' + key)(inputs[i],
+                                                              crop_params)
+        else:
+            crop_params = self._get_params(inputs)
+            inputs = self.apply_image(inputs, crop_params)
+
+        if isinstance(inputs, list):
+            inputs = tuple(inputs)
+        return inputs
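Review note: the point of `PairedRandomCrop` is that the random window is drawn once (from the first element) and reused for every key, so paired samples stay pixel-aligned. A quick sketch:

```python
import numpy as np
from ppgan.datasets.transforms import PairedRandomCrop

crop = PairedRandomCrop(output_size=(256, 256), keys=['image', 'image'])
A = np.random.rand(286, 286, 3).astype(np.float32)
B = A.copy()
A_c, B_c = crop((A, B))
assert A_c.shape == (256, 256, 3)
assert np.array_equal(A_c, B_c)  # identical window applied to both halves
```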
+
+
+@TRANSFORMS.register()
+class RandomHorizontalFlip(Transform):
+    """Horizontally flip the input data randomly with a given probability.
+
+    Args:
+        prob (float): Probability of the input data being flipped.
+            Default: 0.5
+    """
+    def __init__(self, prob=0.5, keys=None):
+        super().__init__()
+        self._set_attributes(locals())
+
+    def apply_image(self, img):
+        if np.random.random() < self.prob:
+            return F.flip(img, code=1)
+        return img
+
+
+@TRANSFORMS.register()
+class PairedRandomHorizontalFlip(RandomHorizontalFlip):
+    def __init__(self, prob=0.5, keys=None):
+        super().__init__()
+        self._set_attributes(locals())
+
+    def apply_image(self, img, flip):
+        if flip:
+            return F.flip(img, code=1)
+        return img
+
+    def __call__(self, inputs):
+        if isinstance(inputs, tuple):
+            inputs = list(inputs)
+        flip = np.random.random() < self.prob
+        if self.keys is not None:
+
+            for i, key in enumerate(self.keys):
+                if isinstance(inputs, dict):
+                    inputs[key] = getattr(self, 'apply_' + key)(inputs[key],
+                                                                flip)
+                elif isinstance(inputs, (list, tuple)):
+                    inputs[i] = getattr(self, 'apply_' + key)(inputs[i], flip)
+        else:
+            inputs = self.apply_image(inputs, flip)
+
+        if isinstance(inputs, list):
+            inputs = tuple(inputs)
+
+        return inputs
+
+
+@TRANSFORMS.register()
+class Normalize(Transform):
+    """Normalize the input data with mean and standard deviation.
+    Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels,
+    this transform will normalize each channel of the input data.
+    ``output[channel] = (input[channel] - mean[channel]) / std[channel]``
+
+    Args:
+        mean (int|float|list): Sequence of means for each channel.
+        std (int|float|list): Sequence of standard deviations for each channel.
+
+    """
+    def __init__(self, mean=0.0, std=1.0, keys=None):
+        super().__init__()
+        self._set_attributes(locals())
+
+        if isinstance(mean, numbers.Number):
+            mean = [mean, mean, mean]
+
+        if isinstance(std, numbers.Number):
+            std = [std, std, std]
+
+        self.mean = np.array(mean, dtype=np.float32).reshape(len(mean), 1, 1)
+        self.std = np.array(std, dtype=np.float32).reshape(len(std), 1, 1)
+
+    def apply_image(self, img):
+        return (img - self.mean) / self.std
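Review note: `PairedRandomHorizontalFlip` draws its coin flip once per call, outside the per-key loop, so both elements flip together. A sketch with `prob=1.0` to make the behavior deterministic (relies on `paddle.vision.transforms.functional.flip(img, code=1)` being available, as the code above does):

```python
import numpy as np
from ppgan.datasets.transforms import PairedRandomHorizontalFlip

flip = PairedRandomHorizontalFlip(prob=1.0, keys=['image', 'image'])
A = np.arange(12, dtype=np.float32).reshape(2, 2, 3)
A_f, B_f = flip((A, A.copy()))
assert np.array_equal(A_f, B_f)  # one flip decision shared by the pair
```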
+ + """ + def __init__(self, mode="CHW", to_rgb=True, keys=None): + super().__init__() + self._set_attributes(locals()) + assert mode in [ + "CHW" + ], "Only support 'CHW' mode, but received mode: {}".format(mode) + self.mode = mode + self.to_rgb = to_rgb + + def apply_image(self, img): + if self.to_rgb: + img = img[..., ::-1] + if self.mode == "CHW": + return img.transpose((2, 0, 1)) + return img + + class Crop(): def __init__(self, pos, size): self.pos = pos @@ -35,6 +271,6 @@ class Crop(): x, y = self.pos th = tw = self.size if (ow > tw or oh > th): - return img[y: y + th, x: x + tw] + return img[y:y + th, x:x + tw] - return img \ No newline at end of file + return img diff --git a/ppgan/datasets/unpaired_dataset.py b/ppgan/datasets/unpaired_dataset.py index 5cabc5391b84e9f6aa55e0925d4202c7b3d09418..232f7bdbbecb3c2d2d8aebe76081156567e0816d 100644 --- a/ppgan/datasets/unpaired_dataset.py +++ b/ppgan/datasets/unpaired_dataset.py @@ -5,13 +5,13 @@ from .base_dataset import BaseDataset, get_transform from .image_folder import make_dataset from .builder import DATASETS +from .transforms.builder import build_transforms @DATASETS.register() class UnpairedDataset(BaseDataset): """ """ - def __init__(self, cfg): """Initialize this dataset class. @@ -19,18 +19,25 @@ class UnpairedDataset(BaseDataset): cfg (dict) -- stores all the experiment flags """ BaseDataset.__init__(self, cfg) - self.dir_A = os.path.join(cfg.dataroot, cfg.phase + 'A') # create a path '/path/to/data/trainA' - self.dir_B = os.path.join(cfg.dataroot, cfg.phase + 'B') # create a path '/path/to/data/trainB' + self.dir_A = os.path.join(cfg.dataroot, cfg.phase + + 'A') # create a path '/path/to/data/trainA' + self.dir_B = os.path.join(cfg.dataroot, cfg.phase + + 'B') # create a path '/path/to/data/trainB' - self.A_paths = sorted(make_dataset(self.dir_A, cfg.max_dataset_size)) # load images from '/path/to/data/trainA' - self.B_paths = sorted(make_dataset(self.dir_B, cfg.max_dataset_size)) # load images from '/path/to/data/trainB' + self.A_paths = sorted(make_dataset( + self.dir_A, + cfg.max_dataset_size)) # load images from '/path/to/data/trainA' + self.B_paths = sorted(make_dataset( + self.dir_B, + cfg.max_dataset_size)) # load images from '/path/to/data/trainB' self.A_size = len(self.A_paths) # get the size of dataset A self.B_size = len(self.B_paths) # get the size of dataset B btoA = self.cfg.direction == 'BtoA' - input_nc = self.cfg.output_nc if btoA else self.cfg.input_nc # get the number of channels of input image - output_nc = self.cfg.input_nc if btoA else self.cfg.output_nc # get the number of channels of output image - self.transform_A = get_transform(self.cfg.transform, grayscale=(input_nc == 1)) - self.transform_B = get_transform(self.cfg.transform, grayscale=(output_nc == 1)) + input_nc = self.cfg.output_nc if btoA else self.cfg.input_nc # get the number of channels of input image + output_nc = self.cfg.input_nc if btoA else self.cfg.output_nc # get the number of channels of output image + + self.transform_A = build_transforms(self.cfg.transforms) + self.transform_B = build_transforms(self.cfg.transforms) self.reset_paths() @@ -49,10 +56,11 @@ class UnpairedDataset(BaseDataset): A_paths (str) -- image paths B_paths (str) -- image paths """ - A_path = self.A_paths[index % self.A_size] # make sure index is within then range - if self.cfg.serial_batches: # make sure index is within then range + A_path = self.A_paths[ + index % self.A_size] # make sure index is within then range + if self.cfg.serial_batches: # make sure 
diff --git a/ppgan/models/builder.py b/ppgan/models/builder.py
index bd2ed58f096679a00549d53b82e9e603e7208433..607f4e915f43eb85b0032981e70081ba03cb2a8c 100644
--- a/ppgan/models/builder.py
+++ b/ppgan/models/builder.py
@@ -2,18 +2,9 @@
 import paddle
 
 from ..utils.registry import Registry
 
-
 MODELS = Registry("MODEL")
 
 
 def build_model(cfg):
-    # dataset = MODELS.get(cfg.MODEL.name)(cfg.MODEL)
-    # place = paddle.CUDAPlace(0)
-    # dataloader = paddle.io.DataLoader(dataset,
-    #                                   batch_size=1, #opt.batch_size,
-    #                                   places=place,
-    #                                   shuffle=True, #not opt.serial_batches,
-    #                                   num_workers=0)#int(opt.num_threads))
     model = MODELS.get(cfg.model.name)(cfg)
     return model
-    # pass
\ No newline at end of file
diff --git a/ppgan/utils/animate.py b/ppgan/utils/animate.py
index 3ac08d9c33c5045f1b1cf725abc3a8e5ef9bff4a..df3a0e71caab05f86cd6a6fa113c43f9b3308a34 100644
--- a/ppgan/utils/animate.py
+++ b/ppgan/utils/animate.py
@@ -1,12 +1,8 @@
-import os
-from tqdm import tqdm
+import numpy as np
+from scipy.spatial import ConvexHull
 
 import paddle
 
-import imageio
-from scipy.spatial import ConvexHull
-import numpy as np
-
 
 def normalize_kp(kp_source,
                  kp_driving,