Unverified commit af0ea59c authored by FNRE, committed by GitHub

Add training of first order motion (#256)

* Add training of first order motion
* Modify code according to review comments
* Modify single-person and multi-person processing of the FOM model
* Get rid of sklearn and skimage dependencies
Parent a62d22e3
......@@ -63,6 +63,11 @@ parser.add_argument(
type=str,
default='sfd',
help="face detector to be used, can choose s3fd or blazeface")
parser.add_argument("--multi_person",
dest="multi_person",
action="store_true",
default=False,
help="whether there is only one person in the image or not")
parser.set_defaults(relative=False)
parser.set_defaults(adapt_scale=False)
......@@ -72,7 +77,6 @@ if __name__ == "__main__":
if args.cpu:
paddle.set_device('cpu')
predictor = FirstOrderPredictor(output=args.output,
filename=args.filename,
weight_path=args.weight_path,
......@@ -82,5 +86,6 @@ if __name__ == "__main__":
find_best_frame=args.find_best_frame,
best_frame=args.best_frame,
ratio=args.ratio,
face_detector=args.face_detector)
face_detector=args.face_detector,
multi_person=args.multi_person)
predictor.run(args.source_image, args.driving_video)
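For reference, a minimal sketch of exercising the new multi_person option from Python instead of the CLI. This assumes the ppgan import path below and that weight_path=None falls back to downloading released weights; both are assumptions, not shown in this diff.

# Hedged sketch: drive the new multi_person flag programmatically.
from ppgan.apps.first_order_predictor import FirstOrderPredictor

predictor = FirstOrderPredictor(output='output',
                                filename='result.mp4',
                                face_detector='sfd',
                                multi_person=True)  # animate every detected face
predictor.run('source.png', 'driving.mp4')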
epochs: 150
output_dir: output_dir
model:
name: FirstOrderModel
common_params:
num_kp: 10
num_channels: 3
estimate_jacobian: True
generator:
name: FirstOrderGenerator
kp_detector_cfg:
temperature: 0.1
block_expansion: 32
max_features: 1024
scale_factor: 0.25
num_blocks: 5
generator_cfg:
block_expansion: 64
max_features: 512
num_down_blocks: 2
num_bottleneck_blocks: 6
estimate_occlusion_map: True
dense_motion_params:
block_expansion: 64
max_features: 1024
num_blocks: 5
scale_factor: 0.25
discriminator:
name: FirstOrderDiscriminator
discriminator_cfg:
scales: [1]
block_expansion: 32
max_features: 512
num_blocks: 4
sn: True
train_params:
scales: [1, 0.5, 0.25, 0.125]
transform_params:
sigma_affine: 0.05
sigma_tps: 0.005
points_tps: 5
loss_weights:
generator_gan: 1
discriminator_gan: 1
feature_matching: [10, 10, 10, 10]
perceptual: [10, 10, 10, 10, 10]
equivariance_value: 10
equivariance_jacobian: 10
optimizer:
name: Adam
lr_scheduler:
epoch_milestones: [187500, 281250]
lr_generator: 2.0e-4
lr_discriminator: 2.0e-4
lr_kp_detector: 2.0e-4
dataset:
train:
name: FirstOrderDataset
phase: train
dataroot: data/first_order/fashion/
num_repeats: 50
time_flip: True
batch_size: 8
id_sampling: False
frame_shape: [ 256, 256, 3 ]
process_time: False
create_frames_folder: False
num_workers: 4
max_dataset_size: inf
direction: BtoA
input_nc: 3
output_nc: 3
transforms:
- name: PairedRandomHorizontalFlip
prob: 0.5
keys: [image, image]
- name: PairedColorJitter
brightness: 0.1
contrast: 0.1
saturation: 0.1
hue: 0.1
keys: [image, image]
test:
name: FirstOrderDataset
dataroot: data/first_order/fashion/
phase: test
batch_size: 1
num_workers: 1
time_flip: False
id_sampling: False
create_frames_folder: False
frame_shape: [ 256, 256, 3 ]
log_config:
interval: 10
visiual_interval: 10
snapshot_config:
interval: 10
validate:
interval: 31250
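Note: epoch_milestones is passed directly to MultiStepDecay in FirstOrderModel.setup_lr_schedulers (firstorder_model.py below). The values 187500 and 281250 far exceed the 150 epochs, so they are evidently iteration counts at which each learning rate decays by gamma=0.1. A small sketch of the resulting schedule:

# Sketch of the LR schedule implied by the optimizer section above.
from paddle.optimizer.lr import MultiStepDecay

sched = MultiStepDecay(learning_rate=2.0e-4,
                       milestones=[187500, 281250],
                       gamma=0.1)
# step < 187500            -> lr = 2.0e-4
# 187500 <= step < 281250  -> lr = 2.0e-5
# step >= 281250           -> lr = 2.0e-6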
......@@ -22,8 +22,6 @@ import pickle
import imageio
import numpy as np
from tqdm import tqdm
from skimage import img_as_ubyte
from skimage.transform import resize
from scipy.spatial import ConvexHull
import paddle
......@@ -47,37 +45,41 @@ class FirstOrderPredictor(BasePredictor):
best_frame=None,
ratio=1.0,
filename='result.mp4',
face_detector='sfd'):
face_detector='sfd',
multi_person=False):
if config is not None and isinstance(config, str):
self.cfg = yaml.load(config, Loader=yaml.SafeLoader)
with open(config) as f:
self.cfg = yaml.load(f, Loader=yaml.SafeLoader)
elif isinstance(config, dict):
self.cfg = config
elif config is None:
self.cfg = {
'model_params': {
'model': {
'common_params': {
'num_kp': 10,
'num_channels': 3,
'estimate_jacobian': True
},
'kp_detector_params': {
'temperature': 0.1,
'block_expansion': 32,
'max_features': 1024,
'scale_factor': 0.25,
'num_blocks': 5
},
'generator_params': {
'block_expansion': 64,
'max_features': 512,
'num_down_blocks': 2,
'num_bottleneck_blocks': 6,
'estimate_occlusion_map': True,
'dense_motion_params': {
'block_expansion': 64,
'generator': {
'kp_detector_cfg': {
'temperature': 0.1,
'block_expansion': 32,
'max_features': 1024,
'num_blocks': 5,
'scale_factor': 0.25
'scale_factor': 0.25,
'num_blocks': 5
},
'generator_cfg': {
'block_expansion': 64,
'max_features': 512,
'num_down_blocks': 2,
'num_bottleneck_blocks': 6,
'estimate_occlusion_map': True,
'dense_motion_params': {
'block_expansion': 64,
'max_features': 1024,
'num_blocks': 5,
'scale_factor': 0.25
}
}
}
}
......@@ -99,28 +101,10 @@ class FirstOrderPredictor(BasePredictor):
self.face_detector = face_detector
self.generator, self.kp_detector = self.load_checkpoints(
self.cfg, self.weight_path)
self.multi_person = multi_person
def run(self, source_image, driving_video):
source_image = imageio.imread(source_image)
bboxes = self.extract_bbox(source_image.copy())
reader = imageio.get_reader(driving_video)
fps = reader.get_meta_data()['fps']
driving_video = []
try:
for im in reader:
driving_video.append(im)
except RuntimeError:
pass
reader.close()
driving_video = [
resize(frame, (256, 256))[..., :3] for frame in driving_video
]
results = []
for rec in bboxes:
face_image = source_image.copy()[rec[1]:rec[3], rec[0]:rec[2]]
face_image = resize(face_image, (256, 256))
def get_prediction(face_image):
if self.find_best_frame or self.best_frame is not None:
i = self.best_frame if self.best_frame is not None else self.find_best_frame_func(
source_image, driving_video)
......@@ -152,7 +136,52 @@ class FirstOrderPredictor(BasePredictor):
self.kp_detector,
relative=self.relative,
adapt_movement_scale=self.adapt_scale)
return predictions
source_image = imageio.imread(source_image)
reader = imageio.get_reader(driving_video)
fps = reader.get_meta_data()['fps']
driving_video = []
try:
for im in reader:
driving_video.append(im)
except RuntimeError:
pass
reader.close()
driving_video = [
cv2.resize(frame, (256, 256)) / 255.0 for frame in driving_video
]
results = []
# for single person
if not self.multi_person:
h, w, _ = source_image.shape
source_image = cv2.resize(source_image, (256, 256)) / 255.0
predictions = get_prediction(source_image)
imageio.mimsave(os.path.join(self.output, self.filename), [
cv2.resize((frame * 255.0).astype('uint8'), (h, w))
for frame in predictions
])
return
bboxes = self.extract_bbox(source_image.copy())
print(str(len(bboxes)) + " persons have been detected")
if len(bboxes) <= 1:
h, w, _ = source_image.shape
source_image = cv2.resize(source_image, (256, 256)) / 255.0
predictions = get_prediction(source_image)
imageio.mimsave(os.path.join(self.output, self.filename), [
cv2.resize((frame * 255.0).astype('uint8'), (h, w))
for frame in predictions
])
return
# for multi person
for rec in bboxes:
face_image = source_image.copy()[rec[1]:rec[3], rec[0]:rec[2]]
face_image = cv2.resize(face_image, (256, 256)) / 255.0
predictions = get_prediction(face_image)
results.append({'rec': rec, 'predict': predictions})
out_frame = []
......@@ -160,7 +189,7 @@ class FirstOrderPredictor(BasePredictor):
for i in range(len(driving_video)):
frame = source_image.copy()
for result in results:
x1, y1, x2, y2 = result['rec']
x1, y1, x2, y2, _ = result['rec']
h = y2 - y1
w = x2 - x1
out = result['predict'][i] * 255.0
......@@ -185,11 +214,12 @@ class FirstOrderPredictor(BasePredictor):
def load_checkpoints(self, config, checkpoint_path):
generator = OcclusionAwareGenerator(
**config['model_params']['generator_params'],
**config['model_params']['common_params'])
**config['model']['generator']['generator_cfg'],
**config['model']['common_params'])
kp_detector = KPDetector(**config['model_params']['kp_detector_params'],
**config['model_params']['common_params'])
kp_detector = KPDetector(
**config['model']['generator']['kp_detector_cfg'],
**config['model']['common_params'])
checkpoint = paddle.load(self.weight_path)
generator.set_state_dict(checkpoint['generator'])
......@@ -269,7 +299,11 @@ class FirstOrderPredictor(BasePredictor):
frame = [image]
predictions = detector.get_detections_for_image(np.array(frame))
person_num = len(predictions)
if person_num == 0:
return np.array([])
results = []
face_boxs = []
h, w, _ = image.shape
for rect in predictions:
bh = rect[3] - rect[1]
......@@ -281,6 +315,37 @@ class FirstOrderPredictor(BasePredictor):
x1 = max(0, cx - int(0.8 * margin))
y2 = min(h, cy + margin)
x2 = min(w, cx + int(0.8 * margin))
results.append([x1, y1, x2, y2])
boxes = np.array(results)
area = (y2 - y1) * (x2 - x1)
results.append([x1, y1, x2, y2, area])
# If a person has more than one bbox, keep only the largest one
# (a greedy, NMS-style pass; an IOU threshold of 0.5 decides "same person").
results.sort(key=lambda box: box[4], reverse=True)  # a bare sorted() would discard its result
results_box = [results[0]]
for i in range(1, person_num):
num = len(results_box)
add_person = True
for j in range(num):
pre_person = results_box[j]
iou = self.IOU(pre_person[0], pre_person[1], pre_person[2],
pre_person[3], pre_person[4], results[i][0],
results[i][1], results[i][2], results[i][3],
results[i][4])
if iou > 0.5:
add_person = False
break
if add_person:
results_box.append(results[i])
boxes = np.array(results_box)
return boxes
def IOU(self, ax1, ay1, ax2, ay2, sa, bx1, by1, bx2, by2, sb):
# sa and sb are the precomputed areas of boxes a and b
x1, y1 = max(ax1, bx1), max(ay1, by1)
x2, y2 = min(ax2, bx2), min(ay2, by2)
w = x2 - x1
h = y2 - y1
if w < 0 or h < 0:
return 0.0
else:
return 1.0 * w * h / (sa + sb - w * h)
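A quick worked example of the dedup above, with hypothetical boxes in the [x1, y1, x2, y2, area] format built by extract_bbox:

# Hypothetical numbers illustrating the IOU-based filtering.
sa = 100 * 100                         # box a: [0, 0, 100, 100], kept first (largest)
sb = 80 * 80                           # box b: [10, 10, 90, 90], nested inside a
inter = (90 - 10) * (90 - 10)          # intersection is box b itself: 6400
iou = 1.0 * inter / (sa + sb - inter)  # 0.64 > 0.5, so box b is discarded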
......@@ -22,3 +22,4 @@ from .animeganv2_dataset import AnimeGANV2Dataset
from .wav2lip_dataset import Wav2LipDataset
from .starganv2_dataset import StarGANv2Dataset
from .edvr_dataset import REDSDataset
from .firstorder_dataset import FirstOrderDataset
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# code was heavily based on https://github.com/AliaksandrSiarohin/first-order-model
import logging
from multiprocessing import Pool
from pathlib import Path
import numpy as np
import tqdm
from imageio import imread, mimread, imwrite
import cv2
from paddle.io import Dataset
from .builder import DATASETS
from .preprocess.builder import build_transforms
import glob, os
POOL_SIZE = 64  # if POOL_SIZE > 0, use multiprocessing to extract frames from video/gif files
@DATASETS.register()
class FirstOrderDataset(Dataset):
def __init__(self, **cfg):
"""Initialize FirstOrder dataset class.
Args:
dataroot (str): Directory of dataset.
phase (str): train or test
num_repeats (int): number of times to repeat the dataset in one epoch
time_flip (bool): whether to exchange the driving image and source image randomly
batch_size (int): dataset batch size
id_sampling (bool): whether to sample person's id
frame_shape (list): image shape
create_frames_folder (bool): for '.mp4' inputs, whether to extract \
    and cache the frames as an image folder on first access
num_workers (int): number of dataloader workers
"""
super(FirstOrderDataset, self).__init__()
self.cfg = cfg
self.frameDataset = FramesDataset(self.cfg)
# create frames folder before 'DatasetRepeater'
if self.cfg['create_frames_folder']:
file_idx_set = [
idx for idx, path in enumerate(self.frameDataset.videos)
if not self.frameDataset.root_dir.joinpath(path).is_dir()
]
file_idx_set = list(file_idx_set)
if len(file_idx_set) != 0:
if POOL_SIZE == 0:
for idx in tqdm.tqdm(file_idx_set,
desc='Extracting frames'):
_ = self.frameDataset[idx]
else:
# multiprocessing
bar = tqdm.tqdm(total=len(file_idx_set),
desc='Extracting frames')
with Pool(POOL_SIZE) as pl:
_p = 0
while _p <= len(file_idx_set) - 1:
_ = pl.map(self.frameDataset.__getitem__,
file_idx_set[_p:_p + POOL_SIZE * 2])
_p = _p + POOL_SIZE * 2
bar.update(POOL_SIZE * 2)
bar.close()
# rewrite video path
self.frameDataset.videos = [
i.with_suffix('') for i in self.frameDataset.videos
]
if self.cfg['phase'] == 'train':
self.outDataset = DatasetRepeater(self.frameDataset,
self.cfg['num_repeats'])
else:
self.outDataset = self.frameDataset
def __len__(self):
return len(self.outDataset)
def __getitem__(self, idx):
return self.outDataset[idx]
def read_video(name: Path, frame_shape=tuple([256, 256, 3]), saveto='folder'):
"""
Read video which can be:
- an image of concatenated frames
- '.mp4' and'.gif'
- folder with videos
"""
if name.is_dir():
frames = sorted(name.iterdir(),
key=lambda x: int(x.with_suffix('').name))
video_array = np.array([imread(path) for path in frames])
elif name.suffix.lower() in ['.gif', '.mp4', '.mov']:
try:
video = mimread(name, memtest=False)
except Exception as err:
logging.error('DataLoading File:%s Msg:%s' % (str(name), str(err)))
return None
# convert to 3-channel image
if video[0].shape[-1] == 4:
video = [i[..., :3] for i in video]
elif video[0].shape[-1] == 1:
video = [np.tile(i, (1, 1, 3)) for i in video]
elif len(video[0].shape) == 2:
video = [np.tile(i[..., np.newaxis], (1, 1, 3)) for i in video]
video_array = np.asarray(video)
video_array_reshape = []
for idx, img in enumerate(video_array):
img = cv2.resize(img, (frame_shape[0], frame_shape[1]))
video_array_reshape.append(img.astype(np.uint8))
video_array_reshape = np.asarray(video_array_reshape)
if saveto == 'folder':
sub_dir = name.with_suffix('')
try:
sub_dir.mkdir()
except FileExistsError:
pass
for idx, img in enumerate(video_array_reshape):
imwrite(sub_dir.joinpath('%i.png' % idx), img)  # imageio accepts Path and keeps RGB order
name.unlink()
else:
raise Exception("Unknown dataset file extensions %s" % name)
return video_array_reshape
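A usage note: with saveto='folder', read_video is destructive by design; it caches the frames next to the source file and then deletes it (the name.unlink() above).

# Hypothetical call: extracts data/clip.mp4 into data/clip/0.png ... N.png,
# returns the uint8 frame array, and removes clip.mp4 afterwards.
from pathlib import Path
frames = read_video(Path('data/clip.mp4'), frame_shape=(256, 256, 3), saveto='folder')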
class FramesDataset(Dataset):
"""
Dataset of videos, each video can be represented as:
- an image of concatenated frames
- '.mp4' or '.gif'
- folder with all frames
FramesDataset[i]: obtain sample from i-th video in self.videos
"""
def __init__(self, cfg):
self.root_dir = Path(cfg['dataroot'])
self.videos = None
self.frame_shape = tuple(cfg['frame_shape'])
self.id_sampling = cfg['id_sampling']
self.time_flip = cfg['time_flip']
self.is_train = True if cfg['phase'] == 'train' else False
self.pairs_list = cfg.setdefault('pairs_list', None)
self.create_frames_folder = cfg['create_frames_folder']
self.transform = None
random_seed = 0
assert self.root_dir.joinpath('train').exists()
assert self.root_dir.joinpath('test').exists()
logging.info("Use predefined train-test split.")
if self.id_sampling:
train_videos = {
video.name.split('#')[0]
for video in self.root_dir.joinpath('train').iterdir()
}
train_videos = list(train_videos)
else:
train_videos = list(self.root_dir.joinpath('train').iterdir())
test_videos = list(self.root_dir.joinpath('test').iterdir())
self.root_dir = self.root_dir.joinpath(
'train' if self.is_train else 'test')
if self.is_train:
self.videos = train_videos
self.transform = build_transforms(cfg['transforms'])
else:
self.videos = test_videos
self.transform = None
def __len__(self):
return len(self.videos)
def __getitem__(self, idx):
if self.is_train and self.id_sampling:
name = self.videos[idx]
path = Path(
np.random.choice(
glob.glob(os.path.join(self.root_dir, name + '*.mp4'))))
else:
path = self.videos[idx]
video_name = path.name
if self.is_train and path.is_dir():
frames = sorted(path.iterdir(),
key=lambda x: int(x.with_suffix('').name))
num_frames = len(frames)
frame_idx = np.sort(
np.random.choice(num_frames, replace=True, size=2))
video_array = [imread(str(frames[idx])) for idx in frame_idx]
else:
if self.create_frames_folder:
video_array = read_video(path,
frame_shape=self.frame_shape,
saveto='folder')
self.videos[idx] = path.with_suffix(
'') # rename /xx/xx/xx.gif -> /xx/xx/xx
else:
video_array = read_video(path,
frame_shape=self.frame_shape,
saveto=None)
num_frames = len(video_array)
frame_idx = np.sort(
np.random.choice(
num_frames, replace=True,
size=2)) if self.is_train else range(num_frames)
video_array = [video_array[i] for i in frame_idx]
# convert to 3-channel image
if video_array[0].shape[-1] == 4:
video_array = [i[..., :3] for i in video_array]
elif video_array[0].shape[-1] == 1:
video_array = [np.tile(i, (1, 1, 3)) for i in video_array]
elif len(video_array[0].shape) == 2:
video_array = [
np.tile(i[..., np.newaxis], (1, 1, 3)) for i in video_array
]
out = {}
if self.is_train:
if self.transform is not None:
t = self.transform(tuple(video_array))
out['driving'] = t[0].transpose(2, 0, 1).astype(
np.float32) / 255.0
out['source'] = t[1].transpose(2, 0, 1).astype(
np.float32) / 255.0
else:
source = np.array(video_array[0],
dtype='float32') / 255.0 # shape is [H, W, C]
driving = np.array(
video_array[1],
dtype='float32') / 255.0 # shape is [H, W, C]
out['driving'] = driving.transpose(2, 0, 1)
out['source'] = source.transpose(2, 0, 1)
if self.time_flip and np.random.rand() < 0.5:
buf = out['driving']
out['driving'] = out['source']
out['source'] = buf
else:
video = np.stack(video_array, axis=0) / 255.0
out['video'] = video.transpose(3, 0, 1, 2)
out['name'] = video_name
return out
def get_sample(self, idx):
return self.__getitem__(idx)
class DatasetRepeater(Dataset):
"""
Pass several times over the same dataset for better i/o performance
"""
def __init__(self, dataset, num_repeats=100):
self.dataset = dataset
self.num_repeats = num_repeats
def __len__(self):
return self.num_repeats * self.dataset.__len__()
def __getitem__(self, idx):
return self.dataset[idx % self.dataset.__len__()]
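With num_repeats: 50 in the config above, one training 'epoch' therefore walks the underlying video list 50 times. A toy sketch of the wrapper's semantics:

# Toy illustration: DatasetRepeater indices wrap around the base dataset.
rep = DatasetRepeater(['a', 'b', 'c'], num_repeats=2)
assert len(rep) == 6
assert [rep[i] for i in range(6)] == ['a', 'b', 'c', 'a', 'b', 'c']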
......@@ -355,3 +355,59 @@ class ResizeToScale(T.BaseTransform):
def _apply_image(self, image):
return F.resize(image, self.params['taget_size'], self.interpolation)
@TRANSFORMS.register()
class PairedColorJitter(T.BaseTransform):
def __init__(self,
brightness=0,
contrast=0,
saturation=0,
hue=0,
keys=None):
super().__init__(keys=keys)
self.brightness = T.transforms._check_input(brightness, 'brightness')
self.contrast = T.transforms._check_input(contrast, 'contrast')
self.saturation = T.transforms._check_input(saturation, 'saturation')
self.hue = T.transforms._check_input(hue,
'hue',
center=0,
bound=(-0.5, 0.5),
clip_first_on_zero=False)
def _get_params(self, input):
"""Get a randomized transform to be applied on image.
Arguments are same as that of __init__.
Returns:
Transform which randomly adjusts brightness, contrast and
saturation in a random order.
"""
transforms = []
if self.brightness is not None:
brightness = random.uniform(self.brightness[0], self.brightness[1])
f = lambda img: F.adjust_brightness(img, brightness)
transforms.append(f)
if self.contrast is not None:
contrast = random.uniform(self.contrast[0], self.contrast[1])
f = lambda img: F.adjust_contrast(img, contrast)
transforms.append(f)
if self.saturation is not None:
saturation = random.uniform(self.saturation[0], self.saturation[1])
f = lambda img: F.adjust_saturation(img, saturation)
transforms.append(f)
if self.hue is not None:
hue = random.uniform(self.hue[0], self.hue[1])
f = lambda img: F.adjust_hue(img, hue)
transforms.append(f)
random.shuffle(transforms)
return transforms
def _apply_image(self, img):
for f in self.params:
img = f(img)
return img
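Unlike applying ColorJitter to each frame independently, PairedColorJitter samples its factors once in _get_params and reuses them for every keyed input, so driving and source frames stay photometrically consistent. A minimal sketch with numpy HWC frames, the same format FirstOrderDataset feeds through build_transforms:

# Sketch: one jitter draw applied identically to both frames of a pair.
import numpy as np

jitter = PairedColorJitter(brightness=0.1, contrast=0.1,
                           saturation=0.1, hue=0.1,
                           keys=['image', 'image'])
frame_a = np.random.randint(0, 256, (256, 256, 3), dtype='uint8')
frame_b = np.random.randint(0, 256, (256, 256, 3), dtype='uint8')
out_a, out_b = jitter((frame_a, frame_b))  # same brightness/contrast/... for both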
......@@ -28,3 +28,4 @@ from .wav2lip_model import Wav2LipModel
from .wav2lip_hq_model import Wav2LipModelHq
from .starganv2_model import StarGANv2Model
from .edvr_model import EDVRModel
from .firstorder_model import FirstOrderModel
......@@ -21,3 +21,4 @@ from .discriminator_styleganv2 import StyleGANv2Discriminator
from .syncnet import SyncNetColor
from .wav2lip_disc_qual import Wav2LipDiscQual
from .discriminator_starganv2 import StarGANv2Discriminator
from .discriminator_firstorder import FirstOrderDiscriminator
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# code was heavily based on https://github.com/AliaksandrSiarohin/first-order-model
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from .builder import DISCRIMINATORS
from ...modules.first_order import ImagePyramide, detach_kp, kp2gaussian
from ...modules.utils import spectral_norm
@DISCRIMINATORS.register()
class FirstOrderDiscriminator(nn.Layer):
"""
Merges all discriminator-related updates into a single model for better multi-GPU usage
Args:
discriminator_cfg:
scales (list): scales of the image pyramid at which features are extracted
block_expansion (int): block_expansion * (2**i) output features for each block i
max_features (int): the number of encoder features is capped at max_features
num_blocks (int): number of blocks for encoding images
sn (bool): whether to use spectral norm
common_params:
num_kp (int): number of keypoints
num_channels (int): image channels
estimate_jacobian (bool): whether to estimate jacobian values of keypoints
train_params:
loss_weights:
discriminator_gan (int): weight of discriminator loss
"""
def __init__(self, discriminator_cfg, common_params, train_params):
super(FirstOrderDiscriminator, self).__init__()
self.discriminator = MultiScaleDiscriminator(**discriminator_cfg,
**common_params)
self.train_params = train_params
self.scales = self.discriminator.scales
self.pyramid = ImagePyramide(self.scales, common_params['num_channels'])
self.loss_weights = train_params['loss_weights']
def forward(self, x, generated):
pyramide_real = self.pyramid(x['driving'])
pyramide_generated = self.pyramid(generated['prediction'].detach())
kp_driving = generated['kp_driving']
discriminator_maps_generated = self.discriminator(
pyramide_generated, kp=detach_kp(kp_driving))
discriminator_maps_real = self.discriminator(pyramide_real,
kp=detach_kp(kp_driving))
loss_values = {}
value_total = 0
for scale in self.scales:
key = 'prediction_map_%s' % scale
value = (1 - discriminator_maps_real[key]
)**2 + discriminator_maps_generated[key]**2
value_total += self.loss_weights['discriminator_gan'] * value.mean()
loss_values['disc_gan'] = value_total
return loss_values
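For reference, the objective computed above is the least-squares GAN discriminator loss, evaluated per pyramid scale and weighted by discriminator_gan. A minimal standalone sketch:

# Sketch of the per-scale LSGAN discriminator objective used above.
import paddle

def lsgan_d_loss(d_real, d_fake):
    return ((1 - d_real) ** 2 + d_fake ** 2).mean()

loss = lsgan_d_loss(paddle.rand([4, 1, 30, 30]), paddle.rand([4, 1, 30, 30]))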
class DownBlock2d(nn.Layer):
"""
Simple block for processing video (encoder).
"""
def __init__(self,
in_features,
out_features,
norm=False,
kernel_size=4,
pool=False,
sn=False):
super(DownBlock2d, self).__init__()
self.conv = nn.Conv2D(in_features,
out_features,
kernel_size=kernel_size)
if sn:
self.conv = spectral_norm(self.conv)
else:
self.sn = None
if norm:
self.norm = nn.InstanceNorm2D(num_features=out_features,
epsilon=1e-05)
else:
self.norm = None
self.pool = pool
def forward(self, x):
out = x
out = self.conv(out)
if self.norm is not None:
out = self.norm(out)
out = F.leaky_relu(out, 0.2)
if self.pool:
out = F.avg_pool2d(out, kernel_size=2, stride=2, ceil_mode=False)
return out
class Discriminator(nn.Layer):
def __init__(self,
num_channels=3,
block_expansion=64,
num_blocks=4,
max_features=512,
sn=False,
use_kp=False,
num_kp=10,
kp_variance=0.01,
**kwargs):
super(Discriminator, self).__init__()
down_blocks = []
for i in range(num_blocks):
down_blocks.append(
DownBlock2d(num_channels + num_kp * use_kp if i == 0 else min(
max_features, block_expansion * (2**i)),
min(max_features, block_expansion * (2**(i + 1))),
norm=(i != 0),
kernel_size=4,
pool=(i != num_blocks - 1),
sn=sn))
self.down_blocks = nn.LayerList(down_blocks)
self.conv = nn.Conv2D(self.down_blocks[len(self.down_blocks) -
1].conv.parameters()[0].shape[0],
1,
kernel_size=1)
if sn:
self.conv = spectral_norm(self.conv)
else:
self.sn = None
self.use_kp = use_kp
self.kp_variance = kp_variance
def forward(self, x, kp=None):
feature_maps = []
out = x
if self.use_kp:
heatmap = kp2gaussian(kp, x.shape[2:], self.kp_variance)
out = paddle.concat([out, heatmap], axis=1)
for down_block in self.down_blocks:
out = down_block(out)
feature_maps.append(out)
out = feature_maps[-1]
prediction_map = self.conv(out)
return feature_maps, prediction_map
class MultiScaleDiscriminator(nn.Layer):
"""
Multi-scale (scale) discriminator
"""
def __init__(self, scales=(), **kwargs):
super(MultiScaleDiscriminator, self).__init__()
self.scales = scales
self.discs = nn.LayerList()
self.nameList = []
for scale in scales:
self.discs.add_sublayer(
str(scale).replace('.', '-'), Discriminator(**kwargs))
self.nameList.append(str(scale).replace('.', '-'))
def forward(self, x, kp=None):
out_dict = {}
for scale, disc in zip(self.nameList, self.discs):
scale = str(scale).replace('-', '.')
key = 'prediction_' + scale
feature_maps, prediction_map = disc(x[key], kp)
out_dict['feature_maps_' + scale] = feature_maps
out_dict['prediction_map_' + scale] = prediction_map
return out_dict
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# code was heavily based on https://github.com/AliaksandrSiarohin/first-order-model
import paddle
from .base_model import BaseModel
from .builder import MODELS
from .discriminators.builder import build_discriminator
from .generators.builder import build_generator
from ..modules.init import init_weights
from ..solver import build_optimizer
from paddle.optimizer.lr import MultiStepDecay
from ..modules.init import reset_parameters, uniform_
import paddle.nn as nn
import numpy as np
from paddle.utils import try_import
import paddle.nn.functional as F
import cv2
def init_weight(net):
def reset_func(m):
if isinstance(m, (nn.BatchNorm, nn.BatchNorm2D, nn.SyncBatchNorm)):
m.weight = uniform_(m.weight, 0, 1)
elif hasattr(m, 'weight') and hasattr(m, 'bias'):
reset_parameters(m)
net.apply(reset_func)
@MODELS.register()
class FirstOrderModel(BaseModel):
""" This class implements the FirstOrderMotion model, FirstOrderMotion paper:
https://proceedings.neurips.cc/paper/2019/file/31c0b36aef265d9221af80872ceb62f9-Paper.pdf.
"""
def __init__(self,
common_params,
train_params,
generator,
discriminator=None):
super(FirstOrderModel, self).__init__()
# define local variables
self.input_data = None
self.generated = None
self.losses_generator = None
self.train_params = train_params
# define networks
generator_cfg = generator
generator_cfg.update({'common_params': common_params})
generator_cfg.update({'train_params': train_params})
generator_cfg.update(
{'dis_scales': discriminator.discriminator_cfg.scales})
self.Gen_Full = build_generator(generator_cfg)
discriminator_cfg = discriminator
discriminator_cfg.update({'common_params': common_params})
discriminator_cfg.update({'train_params': train_params})
self.Dis = build_discriminator(discriminator_cfg)
self.visualizer = Visualizer()
if isinstance(self.Gen_Full, paddle.DataParallel):
self.nets['kp_detector'] = self.Gen_Full._layers.kp_extractor
self.nets['generator'] = self.Gen_Full._layers.generator
self.nets['discriminator'] = self.Dis._layers.discriminator
else:
self.nets['kp_detector'] = self.Gen_Full.kp_extractor
self.nets['generator'] = self.Gen_Full.generator
self.nets['discriminator'] = self.Dis.discriminator
# init params
init_weight(self.nets['kp_detector'])
init_weight(self.nets['generator'])
init_weight(self.nets['discriminator'])
def setup_lr_schedulers(self, lr_cfg):
self.kp_lr = MultiStepDecay(learning_rate=lr_cfg['lr_kp_detector'],
milestones=lr_cfg['epoch_milestones'],
gamma=0.1)
self.gen_lr = MultiStepDecay(learning_rate=lr_cfg['lr_generator'],
milestones=lr_cfg['epoch_milestones'],
gamma=0.1)
self.dis_lr = MultiStepDecay(learning_rate=lr_cfg['lr_discriminator'],
milestones=lr_cfg['epoch_milestones'],
gamma=0.1)
self.lr_scheduler = {
"kp_lr": self.kp_lr,
"gen_lr": self.gen_lr,
"dis_lr": self.dis_lr
}
def setup_optimizers(self, lr_cfg, optimizer):
# define loss functions
self.losses = {}
self.optimizers['optimizer_KP'] = build_optimizer(
optimizer,
self.kp_lr,
parameters=self.nets['kp_detector'].parameters())
self.optimizers['optimizer_Gen'] = build_optimizer(
optimizer,
self.gen_lr,
parameters=self.nets['generator'].parameters())
self.optimizers['optimizer_Dis'] = build_optimizer(
optimizer,
self.dis_lr,
parameters=self.nets['discriminator'].parameters())
def setup_input(self, input):
self.input_data = input
def forward(self):
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
self.losses_generator, self.generated = \
self.Gen_Full(self.input_data.copy(), self.nets['discriminator'])
self.visual_items['driving_source_gen'] = self.visualizer.visualize(
self.input_data['driving'].detach(),
self.input_data['source'].detach(), self.generated)
def backward_G(self):
loss_values = [val.mean() for val in self.losses_generator.values()]
loss = paddle.add_n(loss_values)
self.losses = dict(zip(self.losses_generator.keys(), loss_values))
loss.backward()
def backward_D(self):
losses_discriminator = self.Dis(self.input_data.copy(), self.generated)
loss_values = [val.mean() for val in losses_discriminator.values()]
loss = paddle.add_n(loss_values)
loss.backward()
self.losses.update(dict(zip(losses_discriminator.keys(), loss_values)))
def train_iter(self, optimizers=None):
self.forward()
# update G
self.set_requires_grad(self.nets['discriminator'], False)
self.optimizers['optimizer_KP'].clear_grad()
self.optimizers['optimizer_Gen'].clear_grad()
self.backward_G()
outs = {}
self.optimizers['optimizer_KP'].step()
self.optimizers['optimizer_Gen'].step()
# update D
if self.train_params['loss_weights']['generator_gan'] != 0:
self.set_requires_grad(self.nets['discriminator'], True)
self.optimizers['optimizer_Dis'].clear_grad()
self.backward_D()
self.optimizers['optimizer_Dis'].step()
def test_iter(self, metrics=None):
self.nets['kp_detector'].eval()
self.nets['generator'].eval()
loss_list = []
with paddle.no_grad():
kp_source = self.nets['kp_detector'](self.input_data['video'][:, :,
0])
for frame_idx in range(self.input_data['video'].shape[2]):
source = self.input_data['video'][:, :, 0]
driving = self.input_data['video'][:, :, frame_idx]
kp_driving = self.nets['kp_detector'](driving)
out = self.nets['generator'](source,
kp_source=kp_source,
kp_driving=kp_driving)
loss = paddle.abs(out['prediction'] -
driving).mean().cpu().numpy()
loss_list.append(loss)
print("Reconstruction loss: %s" % np.mean(loss_list))
self.nets['kp_detector'].train()
self.nets['generator'].train()
class Visualizer:
def __init__(self, kp_size=3, draw_border=False, colormap='gist_rainbow'):
plt = try_import('matplotlib.pyplot')
self.kp_size = kp_size
self.draw_border = draw_border
self.colormap = plt.get_cmap(colormap)
def draw_image_with_kp(self, image, kp_array):
image = np.copy(image)
spatial_size = np.array(image.shape[:2][::-1])[np.newaxis]
kp_array = spatial_size * (kp_array + 1) / 2
num_kp = kp_array.shape[0]
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
image = (image * 255).astype(np.uint8)
for kp_ind, kp in enumerate(kp_array):
color = cv2.applyColorMap(
np.array(kp_ind / num_kp * 255).astype(np.uint8),
cv2.COLORMAP_JET)[0][0]
color = (int(color[0]), int(color[1]), int(color[2]))
image = cv2.circle(image, (int(kp[1]), int(kp[0])), self.kp_size,
color, 3)
image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR).astype('float32') / 255.0
return image
def create_image_column_with_kp(self, images, kp):
image_array = np.array(
[self.draw_image_with_kp(v, k) for v, k in zip(images, kp)])
return self.create_image_column(image_array)
def create_image_column(self, images, draw_border=False):
if draw_border:
images = np.copy(images)
images[:, :, [0, -1]] = (1, 1, 1)
return np.concatenate(list(images), axis=0)
def create_image_grid(self, *args):
out = []
for arg in args:
if type(arg) == tuple:
out.append(self.create_image_column_with_kp(arg[0], arg[1]))
else:
out.append(self.create_image_column(arg))
return np.concatenate(out, axis=1)
def visualize(self, driving, source, out):
images = []
# Source image with keypoints
source = source.cpu().numpy()
kp_source = out['kp_source']['value'].cpu().numpy()
source = np.transpose(source, [0, 2, 3, 1])
images.append((source, kp_source))
# Equivariance visualization
if 'transformed_frame' in out:
transformed = out['transformed_frame'].cpu().numpy()
transformed = np.transpose(transformed, [0, 2, 3, 1])
transformed_kp = out['transformed_kp']['value'].cpu().numpy()
images.append((transformed, transformed_kp))
# Driving image with keypoints
kp_driving = out['kp_driving']['value'].cpu().numpy()
driving = driving.cpu().numpy()
driving = np.transpose(driving, [0, 2, 3, 1])
images.append((driving, kp_driving))
# Deformed image
if 'deformed' in out:
deformed = out['deformed'].cpu().numpy()
deformed = np.transpose(deformed, [0, 2, 3, 1])
images.append(deformed)
# Result with and without keypoints
prediction = out['prediction'].cpu().numpy()
prediction = np.transpose(prediction, [0, 2, 3, 1])
if 'kp_norm' in out:
kp_norm = out['kp_norm']['value'].cpu().numpy()
images.append((prediction, kp_norm))
images.append(prediction)
## Occlusion map
if 'occlusion_map' in out:
occlusion_map = out['occlusion_map'].cpu().tile([1, 3, 1, 1])
occlusion_map = F.interpolate(occlusion_map,
size=source.shape[1:3]).numpy()
occlusion_map = np.transpose(occlusion_map, [0, 2, 3, 1])
images.append(occlusion_map)
# Deformed images according to each individual transform
if 'sparse_deformed' in out:
full_mask = []
for i in range(out['sparse_deformed'].shape[1]):
image = out['sparse_deformed'][:, i].cpu()
image = F.interpolate(image, size=source.shape[1:3])
mask = out['mask'][:, i:(i + 1)].cpu().tile([1, 3, 1, 1])
mask = F.interpolate(mask, size=source.shape[1:3])
image = np.transpose(image.numpy(), (0, 2, 3, 1))
mask = np.transpose(mask.numpy(), (0, 2, 3, 1))
if i != 0:
color = np.array(
self.colormap(
(i - 1) /
(out['sparse_deformed'].shape[1] - 1)))[:3]
else:
color = np.array((0, 0, 0))
color = color.reshape((1, 1, 1, 3))
images.append(image)
if i != 0:
images.append(mask * color)
else:
images.append(mask)
full_mask.append(mask * color)
images.append(sum(full_mask))
image = self.create_image_grid(*images)
image = (255 * image).astype(np.uint8)
return image
......@@ -28,4 +28,4 @@ from .generator_pixel2style2pixel import Pixel2Style2Pixel
from .drn import DRNGenerator
from .generator_starganv2 import StarGANv2Generator, StarGANv2Style, StarGANv2Mapping, FAN
from .edvr import EDVRNet
from .generator_firstorder import FirstOrderGenerator
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# code was heavily based on https://github.com/AliaksandrSiarohin/first-order-model
import numpy as np
import paddle
import paddle.nn.functional as F
from paddle import nn
from ppgan.models.generators.builder import GENERATORS
from .occlusion_aware import OcclusionAwareGenerator
from ...modules.first_order import make_coordinate_grid, ImagePyramide, detach_kp
from ...modules.keypoint_detector import KPDetector
import paddle.vision.models.vgg as vgg
from ppgan.utils.download import get_path_from_url
@GENERATORS.register()
class FirstOrderGenerator(nn.Layer):
"""
Args:
kp_detector_cfg:
temperature (float): softmax temperature for the keypoint heatmaps
block_expansion (int): block_expansion * (2**i) output features for each block i
max_features (int): the number of encoder features is capped at max_features
num_blocks (int): number of blocks for encoding images
generator_cfg:
block_expansion (int): block_expansion * (2**i) output features for each block i
max_features (int): the number of encoder features is capped at max_features
num_down_blocks (int): number of downsampling blocks in the encoder
num_bottleneck_blocks (int): number of blocks in the decoder bottleneck
estimate_occlusion_map (bool): whether to estimate the occlusion map
common_params:
num_kp (int): number of keypoints
num_channels (int): image channels
estimate_jacobian (bool): whether to estimate jacobian values of keypoints
train_params:
transform_params: parameters of the random TPS transform used for the equivariance constraints
scales: scales of the image pyramid at which perceptual features are extracted
loss_weights: weights of [generator, discriminator, feature_matching, perceptual,
    equivariance_value, equivariance_jacobian] losses
"""
def __init__(self, generator_cfg, kp_detector_cfg, common_params,
train_params, dis_scales):
super(FirstOrderGenerator, self).__init__()
self.kp_extractor = KPDetector(**kp_detector_cfg, **common_params)
self.generator = OcclusionAwareGenerator(**generator_cfg,
**common_params)
self.train_params = train_params
self.scales = train_params['scales']
self.disc_scales = dis_scales
self.pyramid = ImagePyramide(self.scales, self.generator.num_channels)
self.loss_weights = train_params['loss_weights']
if sum(self.loss_weights['perceptual']) != 0:
self.vgg = VGG19()
def forward(self, x, discriminator):
kp_source = self.kp_extractor(x['source'])
kp_driving = self.kp_extractor(x['driving'])
generated = self.generator(x['source'],
kp_source=kp_source,
kp_driving=kp_driving)
generated.update({'kp_source': kp_source, 'kp_driving': kp_driving})
loss_values = {}
pyramide_real = self.pyramid(x['driving'])
pyramide_generated = self.pyramid(generated['prediction'])
# VGG19 perceptual Loss
if sum(self.loss_weights['perceptual']) != 0:
value_total = 0
for scale in self.scales:
x_vgg = self.vgg(pyramide_generated['prediction_' + str(scale)])
y_vgg = self.vgg(pyramide_real['prediction_' + str(scale)])
for i, weight in enumerate(self.loss_weights['perceptual']):
value = paddle.abs(x_vgg[i] - y_vgg[i].detach()).mean()
value_total += self.loss_weights['perceptual'][i] * value
loss_values['perceptual'] = value_total
# Generator Loss
if self.loss_weights['generator_gan'] != 0:
discriminator_maps_generated = discriminator(
pyramide_generated, kp=detach_kp(kp_driving))
discriminator_maps_real = discriminator(pyramide_real,
kp=detach_kp(kp_driving))
value_total = 0
for scale in self.disc_scales:
key = 'prediction_map_%s' % scale
value = ((1 - discriminator_maps_generated[key])**2).mean()
value_total += self.loss_weights['generator_gan'] * value
loss_values['gen_gan'] = value_total
# Feature matching Loss
if sum(self.loss_weights['feature_matching']) != 0:
value_total = 0
for scale in self.disc_scales:
key = 'feature_maps_%s' % scale
for i, (a, b) in enumerate(
zip(discriminator_maps_real[key],
discriminator_maps_generated[key])):
if self.loss_weights['feature_matching'][i] == 0:
continue
value = paddle.abs(a - b).mean()
value_total += self.loss_weights['feature_matching'][
i] * value
loss_values['feature_matching'] = value_total
if (self.loss_weights['equivariance_value'] +
self.loss_weights['equivariance_jacobian']) != 0:
transform = Transform(x['driving'].shape[0],
**self.train_params['transform_params'])
transformed_frame = transform.transform_frame(x['driving'])
transformed_kp = self.kp_extractor(transformed_frame)
generated['transformed_frame'] = transformed_frame
generated['transformed_kp'] = transformed_kp
# Value loss part
if self.loss_weights['equivariance_value'] != 0:
value = paddle.abs(
kp_driving['value'] -
transform.warp_coordinates(transformed_kp['value'])).mean()
loss_values['equivariance_value'] = self.loss_weights[
'equivariance_value'] * value
# jacobian loss part
if self.loss_weights['equivariance_jacobian'] != 0:
jacobian_transformed = paddle.matmul(
*broadcast(transform.jacobian(transformed_kp['value']),
transformed_kp['jacobian']))
normed_driving = paddle.inverse(kp_driving['jacobian'])
normed_transformed = jacobian_transformed
value = paddle.matmul(
*broadcast(normed_driving, normed_transformed))
eye = paddle.tensor.eye(2, dtype='float32').reshape(
(1, 1, 2, 2))
eye = paddle.tile(eye, [1, value.shape[1], 1, 1])
value = paddle.abs(eye - value).mean()
loss_values['equivariance_jacobian'] = self.loss_weights[
'equivariance_jacobian'] * value
return loss_values, generated
class VGG19(nn.Layer):
"""
Vgg19 network for perceptual loss. See Sec 3.3.
"""
def __init__(self, requires_grad=False):
super(VGG19, self).__init__()
pretrained_url = 'https://paddlegan.bj.bcebos.com/models/vgg19.pdparams'
weight_path = get_path_from_url(pretrained_url)
state_dict = paddle.load(weight_path)
_vgg = getattr(vgg, 'vgg19')()
_vgg.load_dict(state_dict)
vgg_pretrained_features = _vgg.features
self.slice1 = paddle.nn.Sequential()
self.slice2 = paddle.nn.Sequential()
self.slice3 = paddle.nn.Sequential()
self.slice4 = paddle.nn.Sequential()
self.slice5 = paddle.nn.Sequential()
for x in range(2):
self.slice1.add_sublayer(str(x), vgg_pretrained_features[x])
for x in range(2, 7):
self.slice2.add_sublayer(str(x), vgg_pretrained_features[x])
for x in range(7, 12):
self.slice3.add_sublayer(str(x), vgg_pretrained_features[x])
for x in range(12, 21):
self.slice4.add_sublayer(str(x), vgg_pretrained_features[x])
for x in range(21, 30):
self.slice5.add_sublayer(str(x), vgg_pretrained_features[x])
self.register_buffer(
'mean',
paddle.to_tensor([0.485, 0.456, 0.406]).reshape([1, 3, 1, 1]))
# ImageNet normalization statistics, for inputs in the [0, 1] range
self.register_buffer(
'std',
paddle.to_tensor([0.229, 0.224, 0.225]).reshape([1, 3, 1, 1]))
if not requires_grad:
for param in self.parameters():
param.stop_gradient = True
def forward(self, x):
x = (x - self.mean) / self.std
h_relu1 = self.slice1(x)
h_relu2 = self.slice2(h_relu1)
h_relu3 = self.slice3(h_relu2)
h_relu4 = self.slice4(h_relu3)
h_relu5 = self.slice5(h_relu4)
out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
return out
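The five feature maps returned here line up one-to-one with the five entries of loss_weights['perceptual'] ([10, 10, 10, 10, 10]) in the config above. A quick sketch (downloads the pretrained weights on first use):

# Sketch: VGG19 returns one relu feature map per perceptual loss weight.
import paddle

vgg = VGG19()
feats = vgg(paddle.rand([1, 3, 256, 256]))
assert len(feats) == 5  # matches len(loss_weights['perceptual'])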
class Transform:
"""
Random tps transformation for equivariance constraints. See Sec 3.3
"""
def __init__(self, bs, **kwargs):
noise = paddle.distribution.Normal(loc=[0],
scale=[kwargs['sigma_affine']
]).sample([bs, 2, 3])
noise = noise.reshape((bs, 2, 3))
self.theta = noise + paddle.tensor.eye(2, 3, dtype='float32').reshape(
(1, 2, 3))
self.bs = bs
if ('sigma_tps' in kwargs) and ('points_tps' in kwargs):
self.tps = True
self.control_points = make_coordinate_grid(
(kwargs['points_tps'], kwargs['points_tps'])).unsqueeze(0)
buf = paddle.distribution.Normal(
loc=[0], scale=[kwargs['sigma_tps']
]).sample([bs, 1, kwargs['points_tps']**2])
self.control_params = buf.reshape((bs, 1, kwargs['points_tps']**2))
else:
self.tps = False
def transform_frame(self, frame):
grid = make_coordinate_grid(frame.shape[2:], 'float32').unsqueeze(0)
grid = grid.reshape((1, frame.shape[2] * frame.shape[3], 2))
grid = self.warp_coordinates(grid).reshape(
(self.bs, frame.shape[2], frame.shape[3], 2))
return F.grid_sample(frame,
grid,
mode='bilinear',
padding_mode='reflection',
align_corners=True)
def warp_coordinates(self, coordinates):
theta = self.theta.astype('float32')
theta = theta.unsqueeze(1)
coordinates = coordinates.unsqueeze(-1)
# If x1: (1, 5, 2, 2) and x2: (10, 5, 2, 1), torch.matmul would broadcast
# their batch dims to a common shape (10, 5, ...).
# In Paddle this broadcasting must be done manually (see broadcast() below).
theta_part_a = theta[:, :, :, :2]
theta_part_b = theta[:, :, :, 2:]
transformed = paddle.fluid.layers.matmul(
*broadcast(theta_part_a, coordinates)) + theta_part_b #M*p + m0
transformed = transformed.squeeze(-1)
if self.tps:
control_points = self.control_points.astype('float32')
control_params = self.control_params.astype('float32')
distances = coordinates.reshape(
(coordinates.shape[0], -1, 1, 2)) - control_points.reshape(
(1, 1, -1, 2))
distances = distances.abs().sum(-1)
result = distances * distances
result = result * paddle.log(distances + 1e-6)
result = result * control_params
result = result.sum(2).reshape((self.bs, coordinates.shape[1], 1))
transformed = transformed + result
return transformed
def jacobian(self, coordinates):
new_coordinates = self.warp_coordinates(coordinates)
assert len(new_coordinates.shape) == 3
grad_x = paddle.grad(new_coordinates[:, :, 0].sum(),
coordinates,
create_graph=True)
grad_y = paddle.grad(new_coordinates[:, :, 1].sum(),
coordinates,
create_graph=True)
jacobian = paddle.concat(
[grad_x[0].unsqueeze(-2), grad_y[0].unsqueeze(-2)], axis=-2)
return jacobian
def broadcast(x, y):
"""
Broadcast before matmul
"""
if len(x.shape) != len(y.shape):
raise ValueError(x.shape, '!=', y.shape)
*dim_x, _, _ = x.shape
*dim_y, _, _ = y.shape
max_shape = np.max(np.stack([dim_x, dim_y], axis=0), axis=0)
x_bc = paddle.broadcast_to(x, (*max_shape, x.shape[-2], x.shape[-1]))
y_bc = paddle.broadcast_to(y, (*max_shape, y.shape[-2], y.shape[-1]))
return x_bc, y_bc
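A small shape sketch of broadcast as used for the matmuls above, with hypothetical shapes:

# Sketch: align batch dims before paddle.matmul, mirroring torch broadcasting.
import paddle

x = paddle.ones([1, 5, 2, 2])   # e.g. the affine part of theta
y = paddle.ones([10, 5, 2, 1])  # e.g. a batch of column coordinates
xb, yb = broadcast(x, y)        # batch dims of both become (10, 5)
out = paddle.matmul(xb, yb)     # shape [10, 5, 2, 1]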
......@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# code was heavily based on https://github.com/AliaksandrSiarohin/first-order-model
import paddle
from paddle import nn
import paddle.nn.functional as F
......@@ -98,7 +100,11 @@ class OcclusionAwareGenerator(nn.Layer):
mode='bilinear',
align_corners=False)
deformation = deformation.transpose([0, 2, 3, 1])
return F.grid_sample(inp, deformation, align_corners=False)
return F.grid_sample(inp,
deformation,
mode='bilinear',
padding_mode='zeros',
align_corners=True)
def forward(self, source_image, kp_driving, kp_source):
# Encoding (downsampling) part
......
......@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# code was heavily based on https://github.com/AliaksandrSiarohin/first-order-model
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
......@@ -96,9 +98,15 @@ class DenseMotionNetwork(nn.Layer):
jacobian = paddle.matmul(kp_source['jacobian'],
paddle.inverse(kp_driving['jacobian']))
jacobian = jacobian.unsqueeze(-3).unsqueeze(-3)
jacobian = paddle.tile(jacobian, [1, 1, h, w, 1, 1])
coordinate_grid = paddle.matmul(jacobian,
# TODO: workaround for a paddle.tile limitation: reshape to rank 5, tile, then reshape back to rank 6
p_jacobian = jacobian.reshape([bs, self.num_kp, 1, 1, 4])
paddle_jacobian = paddle.tile(p_jacobian, [1, 1, h, w, 1])
paddle_jacobian = paddle_jacobian.reshape(
[bs, self.num_kp, h, w, 2, 2])
coordinate_grid = paddle.matmul(paddle_jacobian,
coordinate_grid.unsqueeze(-1))
coordinate_grid = coordinate_grid.squeeze(-1)
driving_to_source = coordinate_grid + kp_source['value'].reshape(
......@@ -125,7 +133,9 @@ class DenseMotionNetwork(nn.Layer):
(bs * (self.num_kp + 1), h, w, -1))
sparse_deformed = F.grid_sample(source_repeat,
sparse_motions,
align_corners=False)
mode='bilinear',
padding_mode='zeros',
align_corners=True)
sparse_deformed = sparse_deformed.reshape(
(bs, self.num_kp + 1, -1, h, w))
return sparse_deformed
......
......@@ -12,11 +12,47 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# code was heavily based on https://github.com/AliaksandrSiarohin/first-order-model
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
def SyncBatchNorm(*args, **kwargs):
"""In cpu environment nn.SyncBatchNorm does not have kernel so use nn.BatchNorm instead"""
if paddle.get_device() == 'cpu':
return nn.BatchNorm(*args, **kwargs)
else:
return nn.SyncBatchNorm(*args, **kwargs)
class ImagePyramide(nn.Layer):
"""
Create an image pyramid for computing the pyramid perceptual loss. See Sec 3.3
"""
def __init__(self, scales, num_channels):
super(ImagePyramide, self).__init__()
self.downs = paddle.nn.LayerList()
self.name_list = []
for scale in scales:
self.downs.add_sublayer(
str(scale).replace('.', '-'),
AntiAliasInterpolation2d(num_channels, scale))
self.name_list.append(str(scale).replace('.', '-'))
def forward(self, x):
out_dict = {}
for scale, down_module in zip(self.name_list, self.downs):
out_dict['prediction_' +
str(scale).replace('-', '.')] = down_module(x)
return out_dict
def detach_kp(kp):
return {key: value.detach() for key, value in kp.items()}
def kp2gaussian(kp, spatial_size, kp_variance):
"""
Transform a keypoint into a gaussian-like representation
......@@ -26,9 +62,9 @@ def kp2gaussian(kp, spatial_size, kp_variance):
coordinate_grid = make_coordinate_grid(spatial_size, mean.dtype)
number_of_leading_dimensions = len(mean.shape) - 1
shape = (1, ) * number_of_leading_dimensions + tuple(coordinate_grid.shape)
coordinate_grid = coordinate_grid.reshape([*shape])
repeats = tuple(mean.shape[:number_of_leading_dimensions]) + (1, 1, 1)
coordinate_grid = paddle.tile(coordinate_grid, [*repeats])
coordinate_grid = coordinate_grid.reshape(shape)
coordinate_grid = coordinate_grid.tile(repeats)
# Preprocess kp shape
shape = tuple(mean.shape[:number_of_leading_dimensions]) + (1, 1, 2)
......@@ -41,7 +77,7 @@ def kp2gaussian(kp, spatial_size, kp_variance):
return out
def make_coordinate_grid(spatial_size, type):
def make_coordinate_grid(spatial_size, type='float32'):
"""
Create a meshgrid [-1,1] x [-1,1] of given spatial_size.
"""
......@@ -74,8 +110,8 @@ class ResBlock2d(nn.Layer):
out_channels=in_features,
kernel_size=kernel_size,
padding=padding)
self.norm1 = nn.BatchNorm2D(in_features)
self.norm2 = nn.BatchNorm2D(in_features)
self.norm1 = SyncBatchNorm(in_features)
self.norm2 = SyncBatchNorm(in_features)
def forward(self, x):
out = self.norm1(x)
......@@ -105,7 +141,7 @@ class UpBlock2d(nn.Layer):
kernel_size=kernel_size,
padding=padding,
groups=groups)
self.norm = nn.BatchNorm2D(out_features)
self.norm = SyncBatchNorm(out_features)
def forward(self, x):
out = F.interpolate(x, scale_factor=2)
......@@ -131,7 +167,7 @@ class DownBlock2d(nn.Layer):
kernel_size=kernel_size,
padding=padding,
groups=groups)
self.norm = nn.BatchNorm2D(out_features)
self.norm = SyncBatchNorm(out_features)
self.pool = nn.AvgPool2D(kernel_size=(2, 2))
def forward(self, x):
......@@ -158,7 +194,7 @@ class SameBlock2d(nn.Layer):
kernel_size=kernel_size,
padding=padding,
groups=groups)
self.norm = nn.BatchNorm2D(out_features)
self.norm = SyncBatchNorm(out_features)
def forward(self, x):
out = self.conv(x)
......@@ -267,7 +303,7 @@ class AntiAliasInterpolation2d(nn.Layer):
[paddle.arange(size, dtype='float32') for size in kernel_size])
for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
mean = (size - 1) / 2
kernel *= paddle.exp(-(mgrid - mean)**2 / (2 * std**2))
kernel *= paddle.exp(-(mgrid - mean)**2 / (2 * std**2 + 1e-9))
# Make sure sum of values in gaussian kernel equals 1.
kernel = kernel / paddle.sum(kernel)
......@@ -285,6 +321,11 @@ class AntiAliasInterpolation2d(nn.Layer):
out = F.pad(input, [self.ka, self.kb, self.ka, self.kb])
out = F.conv2d(out, weight=self.weight, groups=self.groups)
out = F.interpolate(out, scale_factor=[self.scale, self.scale])
out.stop_gradient = False
inv_scale = 1 / self.scale
int_inv_scale = int(inv_scale)
assert (inv_scale == int_inv_scale)
out = out[:, :, ::int_inv_scale, ::int_inv_scale]
# patch end
return out
......@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# code was heavily based on https://github.com/AliaksandrSiarohin/first-order-model
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
......@@ -51,8 +53,11 @@ class KPDetector(nn.Layer):
out_channels=4 * self.num_jacobian_maps,
kernel_size=(7, 7),
padding=pad)
# zero-init the jacobian weights and identity-init the biases
# (port of the torch reference: weight.data.zero_(), bias = [1, 0, 0, 1] per map)
self.jacobian.weight.set_value(
paddle.zeros(self.jacobian.weight.shape, dtype='float32'))
self.jacobian.bias.set_value(
paddle.to_tensor([1, 0, 0, 1] *
self.num_jacobian_maps).astype('float32'))
else:
self.jacobian = None
......@@ -68,26 +73,21 @@ class KPDetector(nn.Layer):
"""
shape = heatmap.shape
heatmap = heatmap.unsqueeze(-1)
grid = make_coordinate_grid(shape[2:],
heatmap.dtype).unsqueeze(0).unsqueeze(0)
grid = make_coordinate_grid(shape[2:]).unsqueeze([0, 1])
value = (heatmap * grid).sum(axis=(2, 3))
kp = {'value': value}
return kp
def forward(self, x):
if self.scale_factor != 1:
x = self.down(x)
feature_map = self.predictor(x)
prediction = self.kp(feature_map)
final_shape = prediction.shape
heatmap = prediction.reshape([final_shape[0], final_shape[1], -1])
heatmap = F.softmax(heatmap / self.temperature, axis=2)
heatmap = heatmap.reshape([*final_shape])
heatmap = heatmap.reshape(final_shape)
out = self.gaussian2kp(heatmap)
if self.jacobian is not None:
......@@ -97,7 +97,7 @@ class KPDetector(nn.Layer):
final_shape[3]
])
heatmap = heatmap.unsqueeze(2)
heatmap = paddle.tile(heatmap, [1, 1, 4, 1, 1])
jacobian = heatmap * jacobian_map
jacobian = jacobian.reshape([final_shape[0], final_shape[1], 4, -1])
jacobian = jacobian.sum(axis=-1)
......