Unverified commit 9ce257c6, authored by LielinJiang, committed by GitHub

Merge pull request #31 from LielinJiang/transforms

Reproduce transforms module
dataset_params:
root_dir: data/vox-png
frame_shape: [256, 256, 3]
id_sampling: True
pairs_list: data/vox256.csv
augmentation_params:
flip_param:
horizontal_flip: True
time_flip: True
jitter_param:
brightness: 0.1
contrast: 0.1
saturation: 0.1
hue: 0.1
model_params:
common_params:
num_kp: 10
@@ -42,42 +26,3 @@ model_params:
max_features: 512
num_blocks: 4
sn: True
train_params:
num_epochs: 100
num_repeats: 75
epoch_milestones: [60, 90]
lr_generator: 2.0e-4
lr_discriminator: 2.0e-4
lr_kp_detector: 2.0e-4
batch_size: 40
scales: [1, 0.5, 0.25, 0.125]
checkpoint_freq: 50
transform_params:
sigma_affine: 0.05
sigma_tps: 0.005
points_tps: 5
loss_weights:
generator_gan: 0
discriminator_gan: 1
feature_matching: [10, 10, 10, 10]
perceptual: [10, 10, 10, 10, 10]
equivariance_value: 10
equivariance_jacobian: 10
reconstruction_params:
num_videos: 1000
format: '.mp4'
animate_params:
num_pairs: 50
format: '.mp4'
normalization_params:
adapt_movement_scale: False
use_relative_movement: True
use_relative_jacobian: True
visualizer_params:
kp_size: 5
draw_border: True
colormap: 'gist_rainbow'
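For orientation, everything above is plain YAML consumed at startup; a minimal loading sketch (the file name below is hypothetical, not part of this commit):

import yaml

with open('configs/firstorder_vox_256.yaml') as f:  # hypothetical path
    cfg = yaml.safe_load(f)

print(cfg['dataset_params']['frame_shape'])               # [256, 256, 3]
print(cfg['train_params']['loss_weights']['perceptual'])  # [10, 10, 10, 10, 10]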
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import matplotlib
matplotlib.use('Agg')
import os
@@ -5,20 +19,20 @@ import sys
import yaml
import pickle
from argparse import ArgumentParser

import imageio
import numpy as np
from tqdm import tqdm
from skimage import img_as_ubyte
from skimage.transform import resize
from scipy.spatial import ConvexHull

import paddle

from ppgan.models.generators.occlusion_aware import OcclusionAwareGenerator
from ppgan.modules.keypoint_detector import KPDetector
from ppgan.utils.animate import normalize_kp

paddle.disable_static()
if sys.version_info[0] < 3:
@@ -60,8 +74,7 @@ def make_animation(source_image,
predictions = []
source = paddle.to_tensor(source_image[np.newaxis].astype(
np.float32)).transpose([0, 3, 1, 2])
driving = paddle.to_tensor(
np.array(driving_video)[np.newaxis].astype(np.float32)).transpose(
[0, 4, 1, 2, 3])
......
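The two transposes above only reorder axes: the source image goes from HWC to NCHW, and the stacked driving video from NTHWC to NCTHW. A quick shape check, assuming 256x256 frames:

import numpy as np
import paddle

source_image = np.zeros((256, 256, 3), dtype=np.float32)          # H, W, C
driving_video = [np.zeros((256, 256, 3), dtype=np.float32)] * 8   # T frames

source = paddle.to_tensor(source_image[np.newaxis]).transpose([0, 3, 1, 2])
driving = paddle.to_tensor(np.array(driving_video)[np.newaxis]).transpose([0, 4, 1, 2, 3])

print(source.shape)   # [1, 3, 256, 256]     -> N, C, H, W
print(driving.shape)  # [1, 3, 8, 256, 256]  -> N, C, T, H, W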
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
sys.path.append('.')
......
@@ -36,16 +36,18 @@ dataset:
output_nc: 3
serial_batches: False
pool_size: 50
transform:
load_size: 286
crop_size: 256
preprocess: resize_and_crop
no_flip: False
normalize:
mean: (127.5, 127.5, 127.5)
std: (127.5, 127.5, 127.5)
transforms:
- name: Resize
size: [286, 286]
interpolation: 2 #cv2.INTER_CUBIC
- name: RandomCrop
output_size: [256, 256]
- name: RandomHorizontalFlip
prob: 0.5
- name: Permute
- name: Normalize
mean: [127.5, 127.5, 127.5]
std: [127.5, 127.5, 127.5]
test:
name: SingleDataset
dataroot: data/cityscapes/testB
@@ -55,17 +57,14 @@ dataset:
output_nc: 3
serial_batches: False
pool_size: 50
transform:
load_size: 256
crop_size: 256
preprocess: resize_and_crop
no_flip: True
normalize:
mean: (127.5, 127.5, 127.5)
std: (127.5, 127.5, 127.5)
transforms:
- name: Resize
size: [256, 256]
interpolation: 2 #cv2.INTER_CUBIC
- name: Permute
- name: Normalize
mean: [127.5, 127.5, 127.5]
std: [127.5, 127.5, 127.5]
optimizer:
name: Adam
......
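The `transforms:` lists above replace the old `transform:`/`normalize:` keys and are consumed by the registry-based builder added in this PR. A sketch of the unpaired train pipeline, assuming the `ppgan.datasets.transforms.builder` module path:

import numpy as np
from ppgan.datasets.transforms.builder import build_transforms  # assumed path

cfg = [
    {'name': 'Resize', 'size': [286, 286], 'interpolation': 2},  # cv2.INTER_CUBIC
    {'name': 'RandomCrop', 'output_size': [256, 256]},
    {'name': 'RandomHorizontalFlip', 'prob': 0.5},
    {'name': 'Permute'},
    {'name': 'Normalize', 'mean': [127.5] * 3, 'std': [127.5] * 3},
]
pipeline = build_transforms(cfg)

img = np.random.randint(0, 256, (300, 300, 3), dtype=np.uint8)
out = pipeline(img)
print(out.shape)  # (3, 256, 256), values roughly in [-1, 1]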
@@ -35,16 +35,18 @@ dataset:
output_nc: 3
serial_batches: False
pool_size: 50
transform:
load_size: 286
crop_size: 256
preprocess: resize_and_crop
no_flip: False
normalize:
mean: (127.5, 127.5, 127.5)
std: (127.5, 127.5, 127.5)
transforms:
- name: Resize
size: [286, 286]
interpolation: 2 #cv2.INTER_CUBIC
- name: RandomCrop
output_size: [256, 256]
- name: RandomHorizontalFlip
prob: 0.5
- name: Permute
- name: Normalize
mean: [127.5, 127.5, 127.5]
std: [127.5, 127.5, 127.5]
test:
name: SingleDataset
dataroot: data/horse2zebra/testA
@@ -55,15 +57,14 @@ dataset:
serial_batches: False
pool_size: 50
transform:
load_size: 256
crop_size: 256
preprocess: resize_and_crop
no_flip: True
normalize:
mean: (127.5, 127.5, 127.5)
std: (127.5, 127.5, 127.5)
transforms:
- name: Resize
size: [256, 256]
interpolation: 2 #cv2.INTER_CUBIC
- name: Permute
- name: Normalize
mean: [127.5, 127.5, 127.5]
std: [127.5, 127.5, 127.5]
optimizer:
name: Adam
......
@@ -33,16 +33,23 @@ dataset:
output_nc: 3
serial_batches: False
pool_size: 0
transform:
load_size: 286
crop_size: 256
preprocess: resize_and_crop
no_flip: False
normalize:
mean: (127.5, 127.5, 127.5)
std: (127.5, 127.5, 127.5)
transforms:
- name: Resize
size: [286, 286]
interpolation: 2 #cv2.INTER_CUBIC
keys: [image, image]
- name: PairedRandomCrop
output_size: [256, 256]
keys: [image, image]
- name: PairedRandomHorizontalFlip
prob: 0.5
keys: [image, image]
- name: Permute
keys: [image, image]
- name: Normalize
mean: [127.5, 127.5, 127.5]
std: [127.5, 127.5, 127.5]
keys: [image, image]
test:
name: PairedDataset
dataroot: data/cityscapes/
@@ -53,16 +60,18 @@ dataset:
output_nc: 3
serial_batches: True
pool_size: 50
transform:
load_size: 256
crop_size: 256
preprocess: resize_and_crop
no_flip: True
normalize:
mean: (127.5, 127.5, 127.5)
std: (127.5, 127.5, 127.5)
transforms:
- name: Resize
size: [256, 256]
interpolation: 2 #cv2.INTER_CUBIC
keys: [image, image]
- name: Permute
keys: [image, image]
- name: Normalize
mean: [127.5, 127.5, 127.5]
std: [127.5, 127.5, 127.5]
keys: [image, image]
optimizer:
name: Adam
......
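With `keys: [image, image]` each paired transform samples its random parameters once and replays them on both images, so A and B stay pixel-aligned. A usage sketch of the paired pipeline above (module path assumed):

import numpy as np
from ppgan.datasets.transforms.builder import build_transforms  # assumed path

cfg = [
    {'name': 'Resize', 'size': [286, 286], 'interpolation': 2, 'keys': ['image', 'image']},
    {'name': 'PairedRandomCrop', 'output_size': [256, 256], 'keys': ['image', 'image']},
    {'name': 'PairedRandomHorizontalFlip', 'prob': 0.5, 'keys': ['image', 'image']},
    {'name': 'Permute', 'keys': ['image', 'image']},
    {'name': 'Normalize', 'mean': [127.5] * 3, 'std': [127.5] * 3, 'keys': ['image', 'image']},
]
pipeline = build_transforms(cfg)

A = np.zeros((300, 300, 3), dtype=np.uint8)
B = np.zeros((300, 300, 3), dtype=np.uint8)
A_t, B_t = pipeline((A, B))  # one crop offset and one flip decision for the whole pair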
@@ -32,16 +32,23 @@ dataset:
output_nc: 3
serial_batches: False
pool_size: 0
transform:
load_size: 286
crop_size: 256
preprocess: resize_and_crop
no_flip: False
normalize:
mean: (127.5, 127.5, 127.5)
std: (127.5, 127.5, 127.5)
transforms:
- name: Resize
size: [286, 286]
interpolation: 2 #cv2.INTER_CUBIC
keys: [image, image]
- name: PairedRandomCrop
output_size: [256, 256]
keys: [image, image]
- name: PairedRandomHorizontalFlip
prob: 0.5
keys: [image, image]
- name: Permute
keys: [image, image]
- name: Normalize
mean: [127.5, 127.5, 127.5]
std: [127.5, 127.5, 127.5]
keys: [image, image]
test:
name: PairedDataset
dataroot: data/cityscapes/
@@ -52,16 +59,17 @@ dataset:
output_nc: 3
serial_batches: True
pool_size: 50
transform:
load_size: 256
crop_size: 256
preprocess: resize_and_crop
no_flip: True
normalize:
mean: (127.5, 127.5, 127.5)
std: (127.5, 127.5, 127.5)
transforms:
- name: Resize
size: [256, 256]
interpolation: 2 #cv2.INTER_CUBIC
keys: [image, image]
- name: Permute
keys: [image, image]
- name: Normalize
mean: [127.5, 127.5, 127.5]
std: [127.5, 127.5, 127.5]
keys: [image, image]
optimizer:
name: Adam
......
@@ -32,16 +32,23 @@ dataset:
output_nc: 3
serial_batches: False
pool_size: 0
transform:
load_size: 286
crop_size: 256
preprocess: resize_and_crop
no_flip: False
normalize:
mean: (127.5, 127.5, 127.5)
std: (127.5, 127.5, 127.5)
transforms:
- name: Resize
size: [286, 286]
interpolation: 2 #cv2.INTER_CUBIC
keys: [image, image]
- name: PairedRandomCrop
output_size: [256, 256]
keys: [image, image]
- name: PairedRandomHorizontalFlip
prob: 0.5
keys: [image, image]
- name: Permute
keys: [image, image]
- name: Normalize
mean: [127.5, 127.5, 127.5]
std: [127.5, 127.5, 127.5]
keys: [image, image]
test:
name: PairedDataset
dataroot: data/facades/
@@ -52,16 +59,17 @@ dataset:
output_nc: 3
serial_batches: True
pool_size: 50
transform:
load_size: 256
crop_size: 256
preprocess: resize_and_crop
no_flip: True
normalize:
mean: (127.5, 127.5, 127.5)
std: (127.5, 127.5, 127.5)
transforms:
- name: Resize
size: [256, 256]
interpolation: 2 #cv2.INTER_CUBIC
keys: [image, image]
- name: Permute
keys: [image, image]
- name: Normalize
mean: [127.5, 127.5, 127.5]
std: [127.5, 127.5, 127.5]
keys: [image, image]
optimizer:
name: Adam
......
@@ -5,13 +5,13 @@ from .base_dataset import BaseDataset, get_params, get_transform
from .image_folder import make_dataset
from .builder import DATASETS
from .transforms.builder import build_transforms
@DATASETS.register()
class PairedDataset(BaseDataset):
"""A dataset class for paired image dataset.
"""
def __init__(self, cfg):
"""Initialize this dataset class.
@@ -19,11 +19,14 @@ class PairedDataset(BaseDataset):
cfg (dict) -- stores all the experiment flags
"""
BaseDataset.__init__(self, cfg)
self.dir_AB = os.path.join(cfg.dataroot, cfg.phase)  # get the image directory
self.AB_paths = sorted(make_dataset(self.dir_AB, cfg.max_dataset_size))  # get image paths
self.input_nc = self.cfg.output_nc if self.cfg.direction == 'BtoA' else self.cfg.input_nc
self.output_nc = self.cfg.input_nc if self.cfg.direction == 'BtoA' else self.cfg.output_nc
self.transforms = build_transforms(cfg.transforms)
def __getitem__(self, index):
"""Return a data point and its metadata information.
@@ -49,27 +52,11 @@ class PairedDataset(BaseDataset):
A = AB[:h, :w2, :]
B = AB[:h, w2:, :]
# apply the same transform to both A and B
A, B = self.transforms((A, B))
return {'A': A, 'B': B, 'A_paths': AB_path, 'B_paths': AB_path}
def __len__(self):
"""Return the total number of images in the dataset."""
return len(self.AB_paths)
def get_path_by_indexs(self, indexs):
if isinstance(indexs, paddle.Variable):
indexs = indexs.numpy()
current_paths = []
for index in indexs:
current_paths.append(self.AB_paths[index])
return current_paths
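The `__getitem__` above stores A and B side by side in one image and splits it down the middle before handing both halves to the shared pipeline; a minimal sketch of that split:

import numpy as np

AB = np.zeros((256, 512, 3), dtype=np.uint8)  # A|B concatenated horizontally
h, w = AB.shape[:2]
w2 = w // 2
A, B = AB[:h, :w2, :], AB[:h, w2:, :]
print(A.shape, B.shape)  # (256, 256, 3) (256, 256, 3)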
from .transforms import RandomCrop, Resize, RandomHorizontalFlip, PairedRandomCrop, PairedRandomHorizontalFlip, Normalize, Permute
import copy
import traceback
import paddle
from ...utils.registry import Registry
TRANSFORMS = Registry("TRANSFORMS")
class Compose(object):
"""
Composes several transforms together use for composing list of transforms
together for a dataset transform.
Args:
transforms (list): List of transforms to compose.
Returns:
A callable Compose object; calling it applies each of the given
:attr:`transforms` in sequence.
"""
def __init__(self, transforms):
self.transforms = transforms
def __call__(self, data):
for f in self.transforms:
try:
data = f(data)
except Exception as e:
stack_info = traceback.format_exc()
print("fail to perform transform [{}] with error: "
"{} and stack:\n{}".format(f, e, str(stack_info)))
raise e
return data
def build_transforms(cfg):
transforms = []
for trans_cfg in cfg:
temp_trans_cfg = copy.deepcopy(trans_cfg)
name = temp_trans_cfg.pop('name')
transforms.append(TRANSFORMS.get(name)(**temp_trans_cfg))
transforms = Compose(transforms)
return transforms
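`Compose` only requires its entries to be callable, which is what makes registered transforms interchangeable with plain functions; a toy sketch:

import numpy as np

to_float = lambda img: img.astype('float32')
rescale = lambda img: img / 255.0
pipeline = Compose([to_float, rescale])

out = pipeline(np.full((4, 4, 3), 255, dtype=np.uint8))
print(out.max())  # 1.0; a failing transform is logged with its stack trace, then re-raised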
import sys
import random
import numbers
import collections
import numpy as np
from paddle.utils import try_import
import paddle.vision.transforms.functional as F
from .builder import TRANSFORMS

if sys.version_info < (3, 3):
    Sequence = collections.Sequence
    Iterable = collections.Iterable
else:
    Sequence = collections.abc.Sequence
    Iterable = collections.abc.Iterable
class Transform():
def _set_attributes(self, args):
"""
Set attributes from the input list of parameters.
Args:
args (list): list of parameters.
"""
if args:
for k, v in args.items():
if k != "self" and not k.startswith("_"):
setattr(self, k, v)
def apply_image(self, input):
raise NotImplementedError
def __call__(self, inputs):
if isinstance(inputs, tuple):
inputs = list(inputs)
if self.keys is not None:
for i, key in enumerate(self.keys):
if isinstance(inputs, dict):
inputs[key] = getattr(self, 'apply_' + key)(inputs[key])
elif isinstance(inputs, (list, tuple)):
inputs[i] = getattr(self, 'apply_' + key)(inputs[i])
else:
inputs = self.apply_image(inputs)
if isinstance(inputs, list):
inputs = tuple(inputs)
return inputs
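`Transform.__call__` dispatches on `keys`: each key selects an `apply_<key>` method, and without keys the input goes straight through `apply_image`. A minimal hypothetical subclass to illustrate:

import numpy as np

class AddOne(Transform):  # hypothetical, for illustration only
    def __init__(self, keys=None):
        self._set_attributes(locals())

    def apply_image(self, img):
        return img + 1

a, b = AddOne(keys=['image', 'image'])((np.zeros(2), np.ones(2)))  # both routed through apply_image
single = AddOne()(np.zeros(2))  # no keys: apply_image is called on the raw input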
@TRANSFORMS.register()
class Resize(Transform):
"""Resize the input Image to the given size.
Args:
size (int|list|tuple): Desired output size. If size is a sequence like
(h, w), output size will be matched to this. If size is an int,
smaller edge of the image will be matched to this number.
i.e., if height > width, then the image will be rescaled to
(size * height / width, size)
interpolation (int, optional): Interpolation mode of resize. Default: 1.
0 : cv2.INTER_NEAREST
1 : cv2.INTER_LINEAR
2 : cv2.INTER_CUBIC
3 : cv2.INTER_AREA
4 : cv2.INTER_LANCZOS4
5 : cv2.INTER_LINEAR_EXACT
7 : cv2.INTER_MAX
8 : cv2.WARP_FILL_OUTLIERS
16: cv2.WARP_INVERSE_MAP
"""
def __init__(self, size, interpolation=1, keys=None):
super().__init__()
assert isinstance(size, int) or (isinstance(size, Iterable)
and len(size) == 2)
self._set_attributes(locals())
if isinstance(self.size, Iterable):
self.size = tuple(size)
def apply_image(self, img):
return F.resize(img, self.size, self.interpolation)
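Usage sketch for `Resize`, assuming an HWC numpy image and that the cv2-backed `F.resize` used here accepts integer interpolation codes:

import numpy as np

resize = Resize(size=[256, 256], interpolation=2)  # 2 == cv2.INTER_CUBIC
img = np.random.randint(0, 256, (300, 400, 3), dtype=np.uint8)
print(resize(img).shape)  # (256, 256, 3)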
@TRANSFORMS.register()
class RandomCrop(Transform):
def __init__(self, output_size, keys=None):
super().__init__()
self._set_attributes(locals())
if isinstance(output_size, int):
self.output_size = (output_size, output_size)
else:
@@ -19,12 +105,162 @@ class RandomCrop(object):
j = random.randint(0, w - tw)
return i, j, th, tw
def apply_image(self, img):
i, j, h, w = self._get_params(img)
cropped_img = img[i:i + h, j:j + w]
return cropped_img
@TRANSFORMS.register()
class PairedRandomCrop(RandomCrop):
def __init__(self, output_size, keys=None):
super().__init__(output_size, keys)
if isinstance(output_size, int):
self.output_size = (output_size, output_size)
else:
self.output_size = output_size
def apply_image(self, img, crop_params=None):
    if crop_params is not None:
        i, j, h, w = crop_params
    else:
        i, j, h, w = self._get_params(img)
    cropped_img = img[i:i + h, j:j + w]
    return cropped_img
def __call__(self, inputs):
if isinstance(inputs, tuple):
inputs = list(inputs)
if self.keys is not None:
if isinstance(inputs, dict):
crop_params = self._get_params(inputs[self.keys[0]])
elif isinstance(inputs, (list, tuple)):
crop_params = self._get_params(inputs[0])
for i, key in enumerate(self.keys):
if isinstance(inputs, dict):
inputs[key] = getattr(self, 'apply_' + key)(inputs[key],
crop_params)
elif isinstance(inputs, (list, tuple)):
inputs[i] = getattr(self, 'apply_' + key)(inputs[i],
crop_params)
else:
crop_params = self._get_params(inputs)
inputs = self.apply_image(inputs, crop_params)
if isinstance(inputs, list):
inputs = tuple(inputs)
return inputs
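The override above is the heart of the paired variant: crop offsets are drawn once (from the first keyed input) and passed into every `apply_image` call. A quick alignment check:

import numpy as np

crop = PairedRandomCrop(output_size=(8, 8), keys=['image', 'image'])
a = np.arange(16 * 16).reshape(16, 16, 1)
b = a.copy()
a_c, b_c = crop((a, b))
assert (a_c == b_c).all()  # the same (i, j) offsets were applied to both images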
@TRANSFORMS.register()
class RandomHorizontalFlip(Transform):
"""Horizontally flip the input data randomly with a given probability.
Args:
prob (float): Probability of the input data being flipped. Default: 0.5
"""
def __init__(self, prob=0.5, keys=None):
super().__init__()
self._set_attributes(locals())
def apply_image(self, img):
if np.random.random() < self.prob:
return F.flip(img, code=1)
return img
@TRANSFORMS.register()
class PairedRandomHorizontalFlip(RandomHorizontalFlip):
def __init__(self, prob=0.5, keys=None):
super().__init__()
self._set_attributes(locals())
def apply_image(self, img, flip):
if flip:
return F.flip(img, code=1)
return img
def __call__(self, inputs):
if isinstance(inputs, tuple):
inputs = list(inputs)
flip = np.random.random() < self.prob
if self.keys is not None:
for i, key in enumerate(self.keys):
if isinstance(inputs, dict):
inputs[key] = getattr(self, 'apply_' + key)(inputs[key],
flip)
elif isinstance(inputs, (list, tuple)):
inputs[i] = getattr(self, 'apply_' + key)(inputs[i], flip)
else:
inputs = self.apply_image(inputs, flip)
if isinstance(inputs, list):
inputs = tuple(inputs)
return inputs
@TRANSFORMS.register()
class Normalize(Transform):
"""Normalize the input data with mean and standard deviation.
Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels,
this transform will normalize each channel of the input data.
``output[channel] = (input[channel] - mean[channel]) / std[channel]``
Args:
mean (int|float|list): Sequence of means for each channel.
std (int|float|list): Sequence of standard deviations for each channel.
"""
def __init__(self, mean=0.0, std=1.0, keys=None):
super().__init__()
self._set_attributes(locals())
if isinstance(mean, numbers.Number):
mean = [mean, mean, mean]
if isinstance(std, numbers.Number):
std = [std, std, std]
self.mean = np.array(mean, dtype=np.float32).reshape(len(mean), 1, 1)
self.std = np.array(std, dtype=np.float32).reshape(len(std), 1, 1)
def apply_image(self, img):
return (img - self.mean) / self.std
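With `mean = std = 127.5`, as in the configs above, `Normalize` maps uint8 pixel values from [0, 255] into roughly [-1, 1] per channel; note it expects CHW input, since mean and std are reshaped to (C, 1, 1):

import numpy as np

norm = Normalize(mean=[127.5] * 3, std=[127.5] * 3)
chw = np.full((3, 2, 2), 255, dtype=np.float32)
print(norm(chw).max())  # 1.0; a zero pixel maps to -1.0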
@TRANSFORMS.register()
class Permute(Transform):
"""Change input data to a target mode.
For example, most transforms use HWC mode image,
while the Neural Network might use CHW mode input tensor.
Input image should be HWC mode and an instance of numpy.ndarray.
Args:
mode (str): Output mode of input. Default: "CHW".
to_rgb (bool): Convert 'bgr' image to 'rgb'. Default: True.
"""
def __init__(self, mode="CHW", to_rgb=True, keys=None):
super().__init__()
self._set_attributes(locals())
assert mode in [
"CHW"
], "Only support 'CHW' mode, but received mode: {}".format(mode)
self.mode = mode
self.to_rgb = to_rgb
def apply_image(self, img):
if self.to_rgb:
img = img[..., ::-1]
if self.mode == "CHW":
return img.transpose((2, 0, 1))
return img
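`Permute` is what produces that CHW layout: it flips BGR to RGB and moves channels first, which is why it precedes `Normalize` in every config above:

import numpy as np

hwc_bgr = np.zeros((4, 4, 3), dtype=np.uint8)
hwc_bgr[..., 0] = 255  # blue channel in BGR order
chw_rgb = Permute()(hwc_bgr)
print(chw_rgb.shape)     # (3, 4, 4)
print(chw_rgb[2].max())  # 255: blue ends up in the last (B) channel of RGB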
class Crop():
def __init__(self, pos, size):
self.pos = pos
@@ -35,6 +271,6 @@ class Crop():
x, y = self.pos
th = tw = self.size
if ow > tw or oh > th:
    return img[y:y + th, x:x + tw]
return img
@@ -5,13 +5,13 @@ from .base_dataset import BaseDataset, get_transform
from .image_folder import make_dataset
from .builder import DATASETS
from .transforms.builder import build_transforms
@DATASETS.register()
class UnpairedDataset(BaseDataset):
"""
"""
def __init__(self, cfg):
"""Initialize this dataset class.
@@ -19,18 +19,25 @@ class UnpairedDataset(BaseDataset):
cfg (dict) -- stores all the experiment flags
"""
BaseDataset.__init__(self, cfg)
self.dir_A = os.path.join(cfg.dataroot, cfg.phase + 'A')  # create a path '/path/to/data/trainA'
self.dir_B = os.path.join(cfg.dataroot, cfg.phase + 'B')  # create a path '/path/to/data/trainB'
self.A_paths = sorted(make_dataset(self.dir_A, cfg.max_dataset_size))  # load images from '/path/to/data/trainA'
self.B_paths = sorted(make_dataset(self.dir_B, cfg.max_dataset_size))  # load images from '/path/to/data/trainB'
self.A_size = len(self.A_paths) # get the size of dataset A
self.B_size = len(self.B_paths) # get the size of dataset B
btoA = self.cfg.direction == 'BtoA'
input_nc = self.cfg.output_nc if btoA else self.cfg.input_nc  # number of channels of the input image
output_nc = self.cfg.input_nc if btoA else self.cfg.output_nc  # number of channels of the output image
self.transform_A = build_transforms(self.cfg.transforms)
self.transform_B = build_transforms(self.cfg.transforms)
self.reset_paths()
@@ -49,10 +56,11 @@ class UnpairedDataset(BaseDataset):
A_paths (str) -- image paths
B_paths (str) -- image paths
"""
A_path = self.A_paths[index % self.A_size]  # make sure index is within the range
if self.cfg.serial_batches:  # use a fixed, ordered pairing for domain B
index_B = index % self.B_size
else:  # randomize the index for domain B to avoid fixed pairs
index_B = random.randint(0, self.B_size - 1)
B_path = self.B_paths[index_B]
......
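The indexing above is what makes the dataset "unpaired": A is read cyclically while B is drawn either serially or at random, so A-B combinations keep changing across epochs. The same logic in isolation:

import random

A_size, B_size, serial_batches = 1000, 800, False
index = 1203
A_index = index % A_size  # 203, always within range
B_index = index % B_size if serial_batches else random.randint(0, B_size - 1)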
@@ -2,18 +2,9 @@ import paddle
from ..utils.registry import Registry
MODELS = Registry("MODEL")
def build_model(cfg):
    model = MODELS.get(cfg.model.name)(cfg)
    return model
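`build_model` relies on the same registry pattern as the transforms: classes register under `MODELS` and are instantiated by the `name` field of the config. A toy sketch with a hypothetical model class and a minimal config stand-in:

@MODELS.register()
class ToyModel:  # hypothetical; real models live in ppgan.models
    def __init__(self, cfg):
        self.cfg = cfg

class Cfg:  # minimal stand-in for the parsed config
    class model:
        name = 'ToyModel'

model = build_model(Cfg)  # looks up 'ToyModel' in MODELS and instantiates it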
import os

import imageio
import numpy as np
from tqdm import tqdm
from scipy.spatial import ConvexHull

import paddle
def normalize_kp(kp_source,
kp_driving,
......