未验证 提交 9ce257c6 编写于 作者: L LielinJiang 提交者: GitHub

Merge pull request #31 from LielinJiang/transforms

Reproduce transforms module
dataset_params:
root_dir: data/vox-png
frame_shape: [256, 256, 3]
id_sampling: True
pairs_list: data/vox256.csv
augmentation_params:
flip_param:
horizontal_flip: True
time_flip: True
jitter_param:
brightness: 0.1
contrast: 0.1
saturation: 0.1
hue: 0.1
model_params: model_params:
common_params: common_params:
num_kp: 10 num_kp: 10
...@@ -42,42 +26,3 @@ model_params: ...@@ -42,42 +26,3 @@ model_params:
max_features: 512 max_features: 512
num_blocks: 4 num_blocks: 4
sn: True sn: True
train_params:
num_epochs: 100
num_repeats: 75
epoch_milestones: [60, 90]
lr_generator: 2.0e-4
lr_discriminator: 2.0e-4
lr_kp_detector: 2.0e-4
batch_size: 40
scales: [1, 0.5, 0.25, 0.125]
checkpoint_freq: 50
transform_params:
sigma_affine: 0.05
sigma_tps: 0.005
points_tps: 5
loss_weights:
generator_gan: 0
discriminator_gan: 1
feature_matching: [10, 10, 10, 10]
perceptual: [10, 10, 10, 10, 10]
equivariance_value: 10
equivariance_jacobian: 10
reconstruction_params:
num_videos: 1000
format: '.mp4'
animate_params:
num_pairs: 50
format: '.mp4'
normalization_params:
adapt_movement_scale: False
use_relative_movement: True
use_relative_jacobian: True
visualizer_params:
kp_size: 5
draw_border: True
colormap: 'gist_rainbow'
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import matplotlib import matplotlib
matplotlib.use('Agg') matplotlib.use('Agg')
import os import os
...@@ -5,20 +19,20 @@ import sys ...@@ -5,20 +19,20 @@ import sys
import yaml import yaml
import pickle import pickle
from argparse import ArgumentParser
from tqdm import tqdm
import imageio import imageio
import numpy as np import numpy as np
from skimage.transform import resize
from tqdm import tqdm
from skimage import img_as_ubyte from skimage import img_as_ubyte
import paddle from argparse import ArgumentParser
from skimage.transform import resize
from scipy.spatial import ConvexHull
from ppgan.models.generators.occlusion_aware import OcclusionAwareGenerator from ppgan.models.generators.occlusion_aware import OcclusionAwareGenerator
from ppgan.modules.keypoint_detector import KPDetector from ppgan.modules.keypoint_detector import KPDetector
from ppgan.utils.animate import normalize_kp from ppgan.utils.animate import normalize_kp
from scipy.spatial import ConvexHull
import paddle
paddle.disable_static() paddle.disable_static()
if sys.version_info[0] < 3: if sys.version_info[0] < 3:
...@@ -60,8 +74,7 @@ def make_animation(source_image, ...@@ -60,8 +74,7 @@ def make_animation(source_image,
predictions = [] predictions = []
source = paddle.to_tensor(source_image[np.newaxis].astype( source = paddle.to_tensor(source_image[np.newaxis].astype(
np.float32)).transpose([0, 3, 1, 2]) np.float32)).transpose([0, 3, 1, 2])
# if not cpu:
# source = source.cuda()
driving = paddle.to_tensor( driving = paddle.to_tensor(
np.array(driving_video)[np.newaxis].astype(np.float32)).transpose( np.array(driving_video)[np.newaxis].astype(np.float32)).transpose(
[0, 4, 1, 2, 3]) [0, 4, 1, 2, 3])
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys import sys
sys.path.append('.') sys.path.append('.')
......
...@@ -36,16 +36,18 @@ dataset: ...@@ -36,16 +36,18 @@ dataset:
output_nc: 3 output_nc: 3
serial_batches: False serial_batches: False
pool_size: 50 pool_size: 50
transform: transforms:
load_size: 286 - name: Resize
crop_size: 256 size: [286, 286]
preprocess: resize_and_crop interpolation: 2 #cv2.INTER_CUBIC
no_flip: False - name: RandomCrop
normalize: output_size: [256, 256]
mean: - name: RandomHorizontalFlip
(127.5, 127.5, 127.5) prob: 0.5
std: - name: Permute
(127.5, 127.5, 127.5) - name: Normalize
mean: [127.5, 127.5, 127.5]
std: [127.5, 127.5, 127.5]
test: test:
name: SingleDataset name: SingleDataset
dataroot: data/cityscapes/testB dataroot: data/cityscapes/testB
...@@ -55,17 +57,14 @@ dataset: ...@@ -55,17 +57,14 @@ dataset:
output_nc: 3 output_nc: 3
serial_batches: False serial_batches: False
pool_size: 50 pool_size: 50
transform: transforms:
load_size: 256 - name: Resize
crop_size: 256 size: [256, 256]
preprocess: resize_and_crop interpolation: 2 #cv2.INTER_CUBIC
no_flip: True - name: Permute
normalize: - name: Normalize
mean: mean: [127.5, 127.5, 127.5]
(127.5, 127.5, 127.5) std: [127.5, 127.5, 127.5]
std:
(127.5, 127.5, 127.5)
optimizer: optimizer:
name: Adam name: Adam
......
...@@ -35,16 +35,18 @@ dataset: ...@@ -35,16 +35,18 @@ dataset:
output_nc: 3 output_nc: 3
serial_batches: False serial_batches: False
pool_size: 50 pool_size: 50
transform: transforms:
load_size: 286 - name: Resize
crop_size: 256 size: [286, 286]
preprocess: resize_and_crop interpolation: 2 #cv2.INTER_CUBIC
no_flip: False - name: RandomCrop
normalize: output_size: [256, 256]
mean: - name: RandomHorizontalFlip
(127.5, 127.5, 127.5) prob: 0.5
std: - name: Permute
(127.5, 127.5, 127.5) - name: Normalize
mean: [127.5, 127.5, 127.5]
std: [127.5, 127.5, 127.5]
test: test:
name: SingleDataset name: SingleDataset
dataroot: data/horse2zebra/testA dataroot: data/horse2zebra/testA
...@@ -55,15 +57,14 @@ dataset: ...@@ -55,15 +57,14 @@ dataset:
serial_batches: False serial_batches: False
pool_size: 50 pool_size: 50
transform: transform:
load_size: 256 transform:
crop_size: 256 - name: Resize
preprocess: resize_and_crop size: [256, 256]
no_flip: True interpolation: 2 #cv2.INTER_CUBIC
normalize: - name: Permute
mean: - name: Normalize
(127.5, 127.5, 127.5) mean: [127.5, 127.5, 127.5]
std: std: [127.5, 127.5, 127.5]
(127.5, 127.5, 127.5)
optimizer: optimizer:
name: Adam name: Adam
......
...@@ -33,16 +33,23 @@ dataset: ...@@ -33,16 +33,23 @@ dataset:
output_nc: 3 output_nc: 3
serial_batches: False serial_batches: False
pool_size: 0 pool_size: 0
transform: transforms:
load_size: 286 - name: Resize
crop_size: 256 size: [286, 286]
preprocess: resize_and_crop interpolation: 2 #cv2.INTER_CUBIC
no_flip: False keys: [image, image]
normalize: - name: PairedRandomCrop
mean: output_size: [256, 256]
(127.5, 127.5, 127.5) keys: [image, image]
std: - name: PairedRandomHorizontalFlip
(127.5, 127.5, 127.5) prob: 0.5
keys: [image, image]
- name: Permute
keys: [image, image]
- name: Normalize
mean: [127.5, 127.5, 127.5]
std: [127.5, 127.5, 127.5]
keys: [image, image]
test: test:
name: PairedDataset name: PairedDataset
dataroot: data/cityscapes/ dataroot: data/cityscapes/
...@@ -53,16 +60,18 @@ dataset: ...@@ -53,16 +60,18 @@ dataset:
output_nc: 3 output_nc: 3
serial_batches: True serial_batches: True
pool_size: 50 pool_size: 50
transform: transforms:
load_size: 256 - name: Resize
crop_size: 256 size: [256, 256]
preprocess: resize_and_crop interpolation: 2 #cv2.INTER_CUBIC
no_flip: True keys: [image, image]
normalize: - name: Permute
mean: keys: [image, image]
(127.5, 127.5, 127.5) - name: Normalize
std: mean: [127.5, 127.5, 127.5]
(127.5, 127.5, 127.5) std: [127.5, 127.5, 127.5]
keys: [image, image]
optimizer: optimizer:
name: Adam name: Adam
......
...@@ -32,16 +32,23 @@ dataset: ...@@ -32,16 +32,23 @@ dataset:
output_nc: 3 output_nc: 3
serial_batches: False serial_batches: False
pool_size: 0 pool_size: 0
transform: transforms:
load_size: 286 - name: Resize
crop_size: 256 size: [286, 286]
preprocess: resize_and_crop interpolation: 2 #cv2.INTER_CUBIC
no_flip: False keys: [image, image]
normalize: - name: PairedRandomCrop
mean: output_size: [256, 256]
(127.5, 127.5, 127.5) keys: [image, image]
std: - name: PairedRandomHorizontalFlip
(127.5, 127.5, 127.5) prob: 0.5
keys: [image, image]
- name: Permute
keys: [image, image]
- name: Normalize
mean: [127.5, 127.5, 127.5]
std: [127.5, 127.5, 127.5]
keys: [image, image]
test: test:
name: PairedDataset name: PairedDataset
dataroot: data/cityscapes/ dataroot: data/cityscapes/
...@@ -52,16 +59,17 @@ dataset: ...@@ -52,16 +59,17 @@ dataset:
output_nc: 3 output_nc: 3
serial_batches: True serial_batches: True
pool_size: 50 pool_size: 50
transform: transforms:
load_size: 256 - name: Resize
crop_size: 256 size: [256, 256]
preprocess: resize_and_crop interpolation: 2 #cv2.INTER_CUBIC
no_flip: True keys: [image, image]
normalize: - name: Permute
mean: keys: [image, image]
(127.5, 127.5, 127.5) - name: Normalize
std: mean: [127.5, 127.5, 127.5]
(127.5, 127.5, 127.5) std: [127.5, 127.5, 127.5]
keys: [image, image]
optimizer: optimizer:
name: Adam name: Adam
......
...@@ -32,16 +32,23 @@ dataset: ...@@ -32,16 +32,23 @@ dataset:
output_nc: 3 output_nc: 3
serial_batches: False serial_batches: False
pool_size: 0 pool_size: 0
transform: transforms:
load_size: 286 - name: Resize
crop_size: 256 size: [286, 286]
preprocess: resize_and_crop interpolation: 2 #cv2.INTER_CUBIC
no_flip: False keys: [image, image]
normalize: - name: PairedRandomCrop
mean: output_size: [256, 256]
(127.5, 127.5, 127.5) keys: [image, image]
std: - name: PairedRandomHorizontalFlip
(127.5, 127.5, 127.5) prob: 0.5
keys: [image, image]
- name: Permute
keys: [image, image]
- name: Normalize
mean: [127.5, 127.5, 127.5]
std: [127.5, 127.5, 127.5]
keys: [image, image]
test: test:
name: PairedDataset name: PairedDataset
dataroot: data/facades/ dataroot: data/facades/
...@@ -52,16 +59,17 @@ dataset: ...@@ -52,16 +59,17 @@ dataset:
output_nc: 3 output_nc: 3
serial_batches: True serial_batches: True
pool_size: 50 pool_size: 50
transform: transforms:
load_size: 256 - name: Resize
crop_size: 256 size: [256, 256]
preprocess: resize_and_crop interpolation: 2 #cv2.INTER_CUBIC
no_flip: True keys: [image, image]
normalize: - name: Permute
mean: keys: [image, image]
(127.5, 127.5, 127.5) - name: Normalize
std: mean: [127.5, 127.5, 127.5]
(127.5, 127.5, 127.5) std: [127.5, 127.5, 127.5]
keys: [image, image]
optimizer: optimizer:
name: Adam name: Adam
......
...@@ -5,13 +5,13 @@ from .base_dataset import BaseDataset, get_params, get_transform ...@@ -5,13 +5,13 @@ from .base_dataset import BaseDataset, get_params, get_transform
from .image_folder import make_dataset from .image_folder import make_dataset
from .builder import DATASETS from .builder import DATASETS
from .transforms.builder import build_transforms
@DATASETS.register() @DATASETS.register()
class PairedDataset(BaseDataset): class PairedDataset(BaseDataset):
"""A dataset class for paired image dataset. """A dataset class for paired image dataset.
""" """
def __init__(self, cfg): def __init__(self, cfg):
"""Initialize this dataset class. """Initialize this dataset class.
...@@ -19,11 +19,14 @@ class PairedDataset(BaseDataset): ...@@ -19,11 +19,14 @@ class PairedDataset(BaseDataset):
cfg (dict) -- stores all the experiment flags cfg (dict) -- stores all the experiment flags
""" """
BaseDataset.__init__(self, cfg) BaseDataset.__init__(self, cfg)
self.dir_AB = os.path.join(cfg.dataroot, cfg.phase) # get the image directory self.dir_AB = os.path.join(cfg.dataroot,
self.AB_paths = sorted(make_dataset(self.dir_AB, cfg.max_dataset_size)) # get image paths cfg.phase) # get the image directory
assert(self.cfg.transform.load_size >= self.cfg.transform.crop_size) # crop_size should be smaller than the size of loaded image self.AB_paths = sorted(make_dataset(
self.dir_AB, cfg.max_dataset_size)) # get image paths
self.input_nc = self.cfg.output_nc if self.cfg.direction == 'BtoA' else self.cfg.input_nc self.input_nc = self.cfg.output_nc if self.cfg.direction == 'BtoA' else self.cfg.input_nc
self.output_nc = self.cfg.input_nc if self.cfg.direction == 'BtoA' else self.cfg.output_nc self.output_nc = self.cfg.input_nc if self.cfg.direction == 'BtoA' else self.cfg.output_nc
self.transforms = build_transforms(cfg.transforms)
def __getitem__(self, index): def __getitem__(self, index):
"""Return a data point and its metadata information. """Return a data point and its metadata information.
...@@ -49,27 +52,11 @@ class PairedDataset(BaseDataset): ...@@ -49,27 +52,11 @@ class PairedDataset(BaseDataset):
A = AB[:h, :w2, :] A = AB[:h, :w2, :]
B = AB[:h, w2:, :] B = AB[:h, w2:, :]
# apply the same transform to both A and B # apply the same transform to both A and B
# transform_params = get_params(self.opt, A.size) A, B = self.transforms((A, B))
transform_params = get_params(self.cfg.transform, (w2, h))
A_transform = get_transform(self.cfg.transform, transform_params, grayscale=(self.input_nc == 1))
B_transform = get_transform(self.cfg.transform, transform_params, grayscale=(self.output_nc == 1))
A = A_transform(A)
B = B_transform(B)
return {'A': A, 'B': B, 'A_paths': AB_path, 'B_paths': AB_path} return {'A': A, 'B': B, 'A_paths': AB_path, 'B_paths': AB_path}
def __len__(self): def __len__(self):
"""Return the total number of images in the dataset.""" """Return the total number of images in the dataset."""
return len(self.AB_paths) return len(self.AB_paths)
def get_path_by_indexs(self, indexs):
if isinstance(indexs, paddle.Variable):
indexs = indexs.numpy()
current_paths = []
for index in indexs:
current_paths.append(self.AB_paths[index])
return current_paths
from .transforms import RandomCrop, Resize, RandomHorizontalFlip, PairedRandomCrop, PairedRandomHorizontalFlip, Normalize, Permute
import copy
import traceback
import paddle
from ...utils.registry import Registry
TRANSFORMS = Registry("TRANSFORMS")
class Compose(object):
"""
Composes several transforms together use for composing list of transforms
together for a dataset transform.
Args:
transforms (list): List of transforms to compose.
Returns:
A compose object which is callable, __call__ for this Compose
object will call each given :attr:`transforms` sequencely.
"""
def __init__(self, transforms):
self.transforms = transforms
def __call__(self, data):
for f in self.transforms:
try:
data = f(data)
except Exception as e:
stack_info = traceback.format_exc()
print("fail to perform transform [{}] with error: "
"{} and stack:\n{}".format(f, e, str(stack_info)))
raise e
return data
def build_transforms(cfg):
transforms = []
for trans_cfg in cfg:
temp_trans_cfg = copy.deepcopy(trans_cfg)
name = temp_trans_cfg.pop('name')
transforms.append(TRANSFORMS.get(name)(**temp_trans_cfg))
transforms = Compose(transforms)
return transforms
import sys
import random import random
import numbers
import collections
import numpy as np
from paddle.utils import try_import
import paddle.vision.transforms.functional as F
class RandomCrop(object): from .builder import TRANSFORMS
def __init__(self, output_size): if sys.version_info < (3, 3):
Sequence = collections.Sequence
Iterable = collections.Iterable
else:
Sequence = collections.abc.Sequence
Iterable = collections.abc.Iterable
class Transform():
def _set_attributes(self, args):
"""
Set attributes from the input list of parameters.
Args:
args (list): list of parameters.
"""
if args:
for k, v in args.items():
if k != "self" and not k.startswith("_"):
setattr(self, k, v)
def apply_image(self, input):
raise NotImplementedError
def __call__(self, inputs):
if isinstance(inputs, tuple):
inputs = list(inputs)
if self.keys is not None:
for i, key in enumerate(self.keys):
if isinstance(inputs, dict):
inputs[key] = getattr(self, 'apply_' + key)(inputs[key])
elif isinstance(inputs, (list, tuple)):
inputs[i] = getattr(self, 'apply_' + key)(inputs[i])
else:
inputs = self.apply_image(inputs)
if isinstance(inputs, list):
inputs = tuple(inputs)
return inputs
@TRANSFORMS.register()
class Resize(Transform):
"""Resize the input Image to the given size.
Args:
size (int|list|tuple): Desired output size. If size is a sequence like
(h, w), output size will be matched to this. If size is an int,
smaller edge of the image will be matched to this number.
i.e, if height > width, then image will be rescaled to
(size * height / width, size)
interpolation (int, optional): Interpolation mode of resize. Default: 1.
0 : cv2.INTER_NEAREST
1 : cv2.INTER_LINEAR
2 : cv2.INTER_CUBIC
3 : cv2.INTER_AREA
4 : cv2.INTER_LANCZOS4
5 : cv2.INTER_LINEAR_EXACT
7 : cv2.INTER_MAX
8 : cv2.WARP_FILL_OUTLIERS
16: cv2.WARP_INVERSE_MAP
"""
def __init__(self, size, interpolation=1, keys=None):
super().__init__()
assert isinstance(size, int) or (isinstance(size, Iterable)
and len(size) == 2)
self._set_attributes(locals())
if isinstance(self.size, Iterable):
self.size = tuple(size)
def apply_image(self, img):
return F.resize(img, self.size, self.interpolation)
@TRANSFORMS.register()
class RandomCrop(Transform):
def __init__(self, output_size, keys=None):
super().__init__()
self._set_attributes(locals())
if isinstance(output_size, int): if isinstance(output_size, int):
self.output_size = (output_size, output_size) self.output_size = (output_size, output_size)
else: else:
...@@ -19,12 +105,162 @@ class RandomCrop(object): ...@@ -19,12 +105,162 @@ class RandomCrop(object):
j = random.randint(0, w - tw) j = random.randint(0, w - tw)
return i, j, th, tw return i, j, th, tw
def __call__(self, img): def apply_image(self, img):
i, j, h, w = self._get_params(img) i, j, h, w = self._get_params(img)
cropped_img = img[i:i + h, j:j + w] cropped_img = img[i:i + h, j:j + w]
return cropped_img return cropped_img
@TRANSFORMS.register()
class PairedRandomCrop(RandomCrop):
def __init__(self, output_size, keys=None):
super().__init__(output_size, keys)
if isinstance(output_size, int):
self.output_size = (output_size, output_size)
else:
self.output_size = output_size
def apply_image(self, img, crop_prams=None):
if crop_prams is not None:
i, j, h, w = crop_prams
else:
i, j, h, w = self._get_params(img)
cropped_img = img[i:i + h, j:j + w]
return cropped_img
def __call__(self, inputs):
if isinstance(inputs, tuple):
inputs = list(inputs)
if self.keys is not None:
if isinstance(inputs, dict):
crop_params = self._get_params(inputs[self.keys[0]])
elif isinstance(inputs, (list, tuple)):
crop_params = self._get_params(inputs[0])
for i, key in enumerate(self.keys):
if isinstance(inputs, dict):
inputs[key] = getattr(self, 'apply_' + key)(inputs[key],
crop_params)
elif isinstance(inputs, (list, tuple)):
inputs[i] = getattr(self, 'apply_' + key)(inputs[i],
crop_params)
else:
crop_params = self._get_params(inputs)
inputs = self.apply_image(inputs, crop_params)
if isinstance(inputs, list):
inputs = tuple(inputs)
return inputs
@TRANSFORMS.register()
class RandomHorizontalFlip(Transform):
"""Horizontally flip the input data randomly with a given probability.
Args:
prob (float): Probability of the input data being flipped. Default: 0.5
"""
def __init__(self, prob=0.5, keys=None):
super().__init__()
self._set_attributes(locals())
def apply_image(self, img):
if np.random.random() < self.prob:
return F.flip(img, code=1)
return img
@TRANSFORMS.register()
class PairedRandomHorizontalFlip(RandomHorizontalFlip):
def __init__(self, prob=0.5, keys=None):
super().__init__()
self._set_attributes(locals())
def apply_image(self, img, flip):
if flip:
return F.flip(img, code=1)
return img
def __call__(self, inputs):
if isinstance(inputs, tuple):
inputs = list(inputs)
flip = np.random.random() < self.prob
if self.keys is not None:
for i, key in enumerate(self.keys):
if isinstance(inputs, dict):
inputs[key] = getattr(self, 'apply_' + key)(inputs[key],
flip)
elif isinstance(inputs, (list, tuple)):
inputs[i] = getattr(self, 'apply_' + key)(inputs[i], flip)
else:
inputs = self.apply_image(inputs, flip)
if isinstance(inputs, list):
inputs = tuple(inputs)
return inputs
@TRANSFORMS.register()
class Normalize(Transform):
"""Normalize the input data with mean and standard deviation.
Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels,
this transform will normalize each channel of the input data.
``output[channel] = (input[channel] - mean[channel]) / std[channel]``
Args:
mean (int|float|list): Sequence of means for each channel.
std (int|float|list): Sequence of standard deviations for each channel.
"""
def __init__(self, mean=0.0, std=1.0, keys=None):
super().__init__()
self._set_attributes(locals())
if isinstance(mean, numbers.Number):
mean = [mean, mean, mean]
if isinstance(std, numbers.Number):
std = [std, std, std]
self.mean = np.array(mean, dtype=np.float32).reshape(len(mean), 1, 1)
self.std = np.array(std, dtype=np.float32).reshape(len(std), 1, 1)
def apply_image(self, img):
return (img - self.mean) / self.std
@TRANSFORMS.register()
class Permute(Transform):
"""Change input data to a target mode.
For example, most transforms use HWC mode image,
while the Neural Network might use CHW mode input tensor.
Input image should be HWC mode and an instance of numpy.ndarray.
Args:
mode (str): Output mode of input. Default: "CHW".
to_rgb (bool): Convert 'bgr' image to 'rgb'. Default: True.
"""
def __init__(self, mode="CHW", to_rgb=True, keys=None):
super().__init__()
self._set_attributes(locals())
assert mode in [
"CHW"
], "Only support 'CHW' mode, but received mode: {}".format(mode)
self.mode = mode
self.to_rgb = to_rgb
def apply_image(self, img):
if self.to_rgb:
img = img[..., ::-1]
if self.mode == "CHW":
return img.transpose((2, 0, 1))
return img
class Crop(): class Crop():
def __init__(self, pos, size): def __init__(self, pos, size):
self.pos = pos self.pos = pos
...@@ -35,6 +271,6 @@ class Crop(): ...@@ -35,6 +271,6 @@ class Crop():
x, y = self.pos x, y = self.pos
th = tw = self.size th = tw = self.size
if (ow > tw or oh > th): if (ow > tw or oh > th):
return img[y: y + th, x: x + tw] return img[y:y + th, x:x + tw]
return img return img
...@@ -5,13 +5,13 @@ from .base_dataset import BaseDataset, get_transform ...@@ -5,13 +5,13 @@ from .base_dataset import BaseDataset, get_transform
from .image_folder import make_dataset from .image_folder import make_dataset
from .builder import DATASETS from .builder import DATASETS
from .transforms.builder import build_transforms
@DATASETS.register() @DATASETS.register()
class UnpairedDataset(BaseDataset): class UnpairedDataset(BaseDataset):
""" """
""" """
def __init__(self, cfg): def __init__(self, cfg):
"""Initialize this dataset class. """Initialize this dataset class.
...@@ -19,18 +19,25 @@ class UnpairedDataset(BaseDataset): ...@@ -19,18 +19,25 @@ class UnpairedDataset(BaseDataset):
cfg (dict) -- stores all the experiment flags cfg (dict) -- stores all the experiment flags
""" """
BaseDataset.__init__(self, cfg) BaseDataset.__init__(self, cfg)
self.dir_A = os.path.join(cfg.dataroot, cfg.phase + 'A') # create a path '/path/to/data/trainA' self.dir_A = os.path.join(cfg.dataroot, cfg.phase +
self.dir_B = os.path.join(cfg.dataroot, cfg.phase + 'B') # create a path '/path/to/data/trainB' 'A') # create a path '/path/to/data/trainA'
self.dir_B = os.path.join(cfg.dataroot, cfg.phase +
'B') # create a path '/path/to/data/trainB'
self.A_paths = sorted(make_dataset(self.dir_A, cfg.max_dataset_size)) # load images from '/path/to/data/trainA' self.A_paths = sorted(make_dataset(
self.B_paths = sorted(make_dataset(self.dir_B, cfg.max_dataset_size)) # load images from '/path/to/data/trainB' self.dir_A,
cfg.max_dataset_size)) # load images from '/path/to/data/trainA'
self.B_paths = sorted(make_dataset(
self.dir_B,
cfg.max_dataset_size)) # load images from '/path/to/data/trainB'
self.A_size = len(self.A_paths) # get the size of dataset A self.A_size = len(self.A_paths) # get the size of dataset A
self.B_size = len(self.B_paths) # get the size of dataset B self.B_size = len(self.B_paths) # get the size of dataset B
btoA = self.cfg.direction == 'BtoA' btoA = self.cfg.direction == 'BtoA'
input_nc = self.cfg.output_nc if btoA else self.cfg.input_nc # get the number of channels of input image input_nc = self.cfg.output_nc if btoA else self.cfg.input_nc # get the number of channels of input image
output_nc = self.cfg.input_nc if btoA else self.cfg.output_nc # get the number of channels of output image output_nc = self.cfg.input_nc if btoA else self.cfg.output_nc # get the number of channels of output image
self.transform_A = get_transform(self.cfg.transform, grayscale=(input_nc == 1))
self.transform_B = get_transform(self.cfg.transform, grayscale=(output_nc == 1)) self.transform_A = build_transforms(self.cfg.transforms)
self.transform_B = build_transforms(self.cfg.transforms)
self.reset_paths() self.reset_paths()
...@@ -49,7 +56,8 @@ class UnpairedDataset(BaseDataset): ...@@ -49,7 +56,8 @@ class UnpairedDataset(BaseDataset):
A_paths (str) -- image paths A_paths (str) -- image paths
B_paths (str) -- image paths B_paths (str) -- image paths
""" """
A_path = self.A_paths[index % self.A_size] # make sure index is within then range A_path = self.A_paths[
index % self.A_size] # make sure index is within then range
if self.cfg.serial_batches: # make sure index is within then range if self.cfg.serial_batches: # make sure index is within then range
index_B = index % self.B_size index_B = index % self.B_size
else: # randomize the index for domain B to avoid fixed pairs. else: # randomize the index for domain B to avoid fixed pairs.
......
...@@ -2,18 +2,9 @@ import paddle ...@@ -2,18 +2,9 @@ import paddle
from ..utils.registry import Registry from ..utils.registry import Registry
MODELS = Registry("MODEL") MODELS = Registry("MODEL")
def build_model(cfg): def build_model(cfg):
# dataset = MODELS.get(cfg.MODEL.name)(cfg.MODEL)
# place = paddle.CUDAPlace(0)
# dataloader = paddle.io.DataLoader(dataset,
# batch_size=1, #opt.batch_size,
# places=place,
# shuffle=True, #not opt.serial_batches,
# num_workers=0)#int(opt.num_threads))
model = MODELS.get(cfg.model.name)(cfg) model = MODELS.get(cfg.model.name)(cfg)
return model return model
# pass
\ No newline at end of file
import os import numpy as np
from tqdm import tqdm from scipy.spatial import ConvexHull
import paddle import paddle
import imageio
from scipy.spatial import ConvexHull
import numpy as np
def normalize_kp(kp_source, def normalize_kp(kp_source,
kp_driving, kp_driving,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册