From d913d8d79891c2859b356847e2d442efb7f1a6c4 Mon Sep 17 00:00:00 2001 From: lzzyzlbb <287246233@qq.com> Date: Wed, 21 Jul 2021 18:08:57 +0800 Subject: [PATCH] 1.add face enhancement. 2.fix edge problem (#367) * add stargan pretrain model * 1.add face enhancement. 2.fix edge problem * 1.add face enhancement. 2.fix edge problem * 1.add face enhancement. 2.fix edge problem * 1.add face enhancement. 2.fix edge problem * 1.add face enhancement. 2.fix edge problem * 1.add face enhancement. 2.fix edge problem --- applications/tools/first-order-demo.py | 15 +++- applications/tools/wav2lip.py | 9 +- docs/en_US/tutorials/motion_driving.md | 12 ++- docs/en_US/tutorials/wav2lip.md | 8 +- docs/zh_CN/tutorials/motion_driving.md | 12 ++- docs/zh_CN/tutorials/wav2lip.md | 8 +- ppgan/apps/first_order_predictor.py | 51 +++++++---- ppgan/apps/wav2lip_predictor.py | 9 +- ppgan/faceutils/face_enhancement/__init__.py | 15 ++++ .../face_enhancement/face_enhance.py | 78 ++++++++++++++++ ppgan/models/generators/__init__.py | 1 + .../models/generators/generator_styleganv2.py | 64 +++++++++----- ppgan/models/generators/gpen.py | 88 +++++++++++++++++++ ppgan/models/generators/occlusion_aware.py | 25 +++++- 14 files changed, 350 insertions(+), 45 deletions(-) create mode 100644 ppgan/faceutils/face_enhancement/__init__.py create mode 100644 ppgan/faceutils/face_enhancement/face_enhance.py create mode 100644 ppgan/models/generators/gpen.py diff --git a/applications/tools/first-order-demo.py b/applications/tools/first-order-demo.py index cfcadbe..588e9c3 100644 --- a/applications/tools/first-order-demo.py +++ b/applications/tools/first-order-demo.py @@ -73,8 +73,19 @@ parser.add_argument("--image_size", type=int, default=256, help="size of image") +parser.add_argument("--batch_size", + dest="batch_size", + type=int, + default=1, + help="Batch size for fom model") +parser.add_argument( + "--face_enhancement", + dest="face_enhancement", + action="store_true", + help="use face enhance for face") parser.set_defaults(relative=False) parser.set_defaults(adapt_scale=False) +parser.set_defaults(face_enhancement=False) if __name__ == "__main__": args = parser.parse_args() @@ -92,5 +103,7 @@ if __name__ == "__main__": ratio=args.ratio, face_detector=args.face_detector, multi_person=args.multi_person, - image_size=args.image_size) + image_size=args.image_size, + batch_size=args.batch_size, + face_enhancement=args.face_enhancement) predictor.run(args.source_image, args.driving_video) diff --git a/applications/tools/wav2lip.py b/applications/tools/wav2lip.py index 5fd472e..ac655a6 100644 --- a/applications/tools/wav2lip.py +++ b/applications/tools/wav2lip.py @@ -103,6 +103,12 @@ parser.add_argument( type=str, default='sfd', help="face detector to be used, can choose s3fd or blazeface") +parser.add_argument( + "--face_enhancement", + dest="face_enhancement", + action="store_true", + help="use face enhance for face") +parser.set_defaults(face_enhancement=False) if __name__ == "__main__": args = parser.parse_args() @@ -120,5 +126,6 @@ if __name__ == "__main__": box = args.box, rotate = args.rotate, nosmooth = args.nosmooth, - face_detector = args.face_detector) + face_detector = args.face_detector, + face_enhancement = args.face_enhancement) predictor.run(args.face, args.audio, args.outfile) diff --git a/docs/en_US/tutorials/motion_driving.md b/docs/en_US/tutorials/motion_driving.md index ae375ce..5d4e38d 100644 --- a/docs/en_US/tutorials/motion_driving.md +++ b/docs/en_US/tutorials/motion_driving.md @@ -33,7 +33,8 @@ python -u 
tools/first-order-demo.py \ --source_image ../docs/imgs/fom_source_image.png \ --ratio 0.4 \ --relative --adapt_scale \ - --image_size 512 + --image_size 512 \ + --face_enhancement ``` - multi face: @@ -56,7 +57,16 @@ python -u tools/first-order-demo.py \ - ratio: The pasted face percentage of generated image, this parameter should be adjusted in the case of multi-person image in which the adjacent faces are close. The defualt value is 0.4 and the range is [0.4, 0.5]. - image_size: The image size of the face. Default is 256 - multi_person: There are multi faces in the images. Default means only one face in the image +- face_enhancement: enhance the face, default is False ``` +result of face_enhancement: +
(before/after comparison images of face enhancement)
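The new options can also be driven from Python instead of the CLI. A minimal sketch, assuming the constructor keywords introduced by this patch in `ppgan/apps/first_order_predictor.py`; the driving-video path is a placeholder and unlisted arguments keep their defaults:

```python
from ppgan.apps.first_order_predictor import FirstOrderPredictor

# keyword names follow the signature changed in this patch;
# batch_size and face_enhancement are the two new arguments
predictor = FirstOrderPredictor(filename="result.mp4",
                                image_size=512,
                                batch_size=4,           # frames generated per forward pass
                                face_enhancement=True)  # run GPEN on each generated face
predictor.run("../docs/imgs/fom_source_image.png", "driving_video.mp4")
```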
+ ### 2 Training **Datasets:** diff --git a/docs/en_US/tutorials/wav2lip.md b/docs/en_US/tutorials/wav2lip.md index 7c7ffbf..4411b54 100644 --- a/docs/en_US/tutorials/wav2lip.md +++ b/docs/en_US/tutorials/wav2lip.md @@ -11,13 +11,19 @@ Runing the following command to complete the lip-syning task. The output is the ``` cd applications -python tools/wav2lip.py --face ../docs/imgs/mona7s.mp4 --audio ../docs/imgs/guangquan.m4a --outfile pp_guangquan_mona7s.mp4 +python tools/wav2lip.py \ + --face ../docs/imgs/mona7s.mp4 \ + --audio ../docs/imgs/guangquan.m4a \ + --outfile pp_guangquan_mona7s.mp4 \ + --face_enhancement ``` **params:** - face: path of the input image or video file including faces. - audio: path of the input audio file, format can be `.wav`, `.mp3`, `.m4a`. It can be any file supported by `FFMPEG` containing audio data. +- outfile: result video of wav2lip +- face_enhancement: enhance the face, default is False ### 2.2 Training 1. Our model are trained on LRS2. See [here](https://github.com/Rudrabha/Wav2Lip#training-on-datasets-other-than-lrs2) for a few suggestions regarding training on other datasets. diff --git a/docs/zh_CN/tutorials/motion_driving.md b/docs/zh_CN/tutorials/motion_driving.md index af2819b..0dadb1e 100644 --- a/docs/zh_CN/tutorials/motion_driving.md +++ b/docs/zh_CN/tutorials/motion_driving.md @@ -40,7 +40,8 @@ python -u tools/first-order-demo.py \ --source_image ../docs/imgs/fom_source_image.png \ --ratio 0.4 \ --relative --adapt_scale \ - --image_size 512 + --image_size 512 \ + --face_enhancement ``` - 多人脸: ``` @@ -60,7 +61,16 @@ python -u tools/first-order-demo.py \ - ratio: 贴回驱动生成的人脸区域占原图的比例, 用户需要根据生成的效果调整该参数,尤其对于多人脸距离比较近的情况下需要调整改参数, 默认为0.4,调整范围是[0.4, 0.5] - image_size: 图片人脸大小,默认为256 - multi_person: 表示图片中有多张人脸,不加则默认为单人脸 +- face_enhancement: 添加人脸增强,默认为false ``` +添加人脸增强对比如下: +
(人脸增强前后效果对比图)
+ ### 2 训练 **数据集:** diff --git a/docs/zh_CN/tutorials/wav2lip.md b/docs/zh_CN/tutorials/wav2lip.md index d2d2f6d..900a12b 100644 --- a/docs/zh_CN/tutorials/wav2lip.md +++ b/docs/zh_CN/tutorials/wav2lip.md @@ -13,11 +13,17 @@ Wav2Lip实现的是视频人物根据输入音频生成与语音同步的人物 ``` cd applications -python tools/wav2lip.py --face ../docs/imgs/mona7s.mp4 --audio ../docs/imgs/guangquan.m4a --outfile pp_guangquan_mona7s.mp4 +python tools/wav2lip.py \ + --face ../docs/imgs/mona7s.mp4 \ + --audio ../docs/imgs/guangquan.m4a \ + --outfile pp_guangquan_mona7s.mp4 + --face_enhancement ``` **参数说明:** - face: 视频或图片,视频或图片中的人物唇形将根据音频进行唇形合成,以和音频同步 - audio: 驱动唇形合成的音频,视频中的人物将根据此音频进行唇形合成 +- outfile: 合成的视频 +- face_enhancement: 添加人脸增强,默认为false ### 2.2 训练 1. 我们的模型是基于LRS2数据集训练的。可以参考[这里](https://github.com/Rudrabha/Wav2Lip#training-on-datasets-other-than-lrs2)获得在其它训练集上进行训练的一些建议。 diff --git a/ppgan/apps/first_order_predictor.py b/ppgan/apps/first_order_predictor.py index 876fa58..a225d20 100644 --- a/ppgan/apps/first_order_predictor.py +++ b/ppgan/apps/first_order_predictor.py @@ -47,7 +47,9 @@ class FirstOrderPredictor(BasePredictor): filename='result.mp4', face_detector='sfd', multi_person=False, - image_size=256): + image_size=256, + face_enhancement=False, + batch_size=1): if config is not None and isinstance(config, str): with open(config) as f: self.cfg = yaml.load(f, Loader=yaml.SafeLoader) @@ -107,6 +109,11 @@ class FirstOrderPredictor(BasePredictor): self.generator, self.kp_detector = self.load_checkpoints( self.cfg, self.weight_path) self.multi_person = multi_person + self.face_enhancement = face_enhancement + self.batch_size = batch_size + if face_enhancement: + from ppgan.faceutils.face_enhancement import FaceEnhancement + self.faceenhancer = FaceEnhancement(batch_size=batch_size) def read_img(self, path): img = imageio.imread(path) @@ -177,7 +184,7 @@ class FirstOrderPredictor(BasePredictor): face_image = source_image.copy()[rec[1]:rec[3], rec[0]:rec[2]] face_image = cv2.resize(face_image, (self.image_size, self.image_size)) / 255.0 predictions = get_prediction(face_image) - results.append({'rec': rec, 'predict': predictions}) + results.append({'rec': rec, 'predict': [predictions[i] for i in range(predictions.shape[0])]}) if len(bboxes) == 1 or not self.multi_person: break out_frame = [] @@ -188,7 +195,7 @@ class FirstOrderPredictor(BasePredictor): x1, y1, x2, y2, _ = result['rec'] h = y2 - y1 w = x2 - x1 - out = result['predict'][i] * 255.0 + out = result['predict'][i] out = cv2.resize(out.astype(np.uint8), (x2 - x1, y2 - y1)) if len(results) == 1: frame[y1:y2, x1:x2] = out @@ -212,7 +219,7 @@ class FirstOrderPredictor(BasePredictor): generator = OcclusionAwareGenerator( **config['model']['generator']['generator_cfg'], - **config['model']['common_params']) + **config['model']['common_params'], inference=True) kp_detector = KPDetector( **config['model']['generator']['kp_detector_cfg'], @@ -241,14 +248,23 @@ class FirstOrderPredictor(BasePredictor): np.float32)).transpose([0, 3, 1, 2]) driving = paddle.to_tensor( - np.array(driving_video)[np.newaxis].astype( - np.float32)).transpose([0, 4, 1, 2, 3]) + np.array(driving_video).astype( + np.float32)).transpose([0, 3, 1, 2]) kp_source = kp_detector(source) - kp_driving_initial = kp_detector(driving[:, :, 0]) - - for frame_idx in tqdm(range(driving.shape[2])): - driving_frame = driving[:, :, frame_idx] + kp_driving_initial = kp_detector(driving[0:1]) + kp_source_batch = {} + kp_source_batch["value"] = paddle.tile(kp_source["value"], repeat_times=[self.batch_size,1,1]) + 
kp_source_batch["jacobian"] = paddle.tile(kp_source["jacobian"], repeat_times=[self.batch_size,1,1,1]) + source = paddle.tile(source, repeat_times=[self.batch_size,1,1,1]) + begin_idx = 0 + for frame_idx in tqdm(range(int(np.ceil(float(driving.shape[0]) / self.batch_size)))): + frame_num = min(self.batch_size, driving.shape[0] - begin_idx) + driving_frame = driving[begin_idx: begin_idx+frame_num] kp_driving = kp_detector(driving_frame) + kp_source_img = {} + kp_source_img["value"] = kp_source_batch["value"][0:frame_num] + kp_source_img["jacobian"] = kp_source_batch["jacobian"][0:frame_num] + kp_norm = normalize_kp( kp_source=kp_source, kp_driving=kp_driving, @@ -256,11 +272,16 @@ class FirstOrderPredictor(BasePredictor): use_relative_movement=relative, use_relative_jacobian=relative, adapt_movement_scale=adapt_movement_scale) - out = generator(source, kp_source=kp_source, kp_driving=kp_norm) - - predictions.append( - np.transpose(out['prediction'].numpy(), [0, 2, 3, 1])[0]) - return predictions + + out = generator(source[0:frame_num], kp_source=kp_source_img, kp_driving=kp_norm) + img = np.transpose(out['prediction'].numpy(), [0, 2, 3, 1]) * 255.0 + + if self.face_enhancement: + img = self.faceenhancer.enhance_from_batch(img) + + predictions.append(img) + begin_idx += frame_num + return np.concatenate(predictions) def find_best_frame_func(self, source, driving): import face_alignment diff --git a/ppgan/apps/wav2lip_predictor.py b/ppgan/apps/wav2lip_predictor.py index 26a488c..152eedc 100644 --- a/ppgan/apps/wav2lip_predictor.py +++ b/ppgan/apps/wav2lip_predictor.py @@ -28,7 +28,8 @@ class Wav2LipPredictor(BasePredictor): box = [-1, -1, -1, -1], rotate = False, nosmooth = False, - face_detector = 'sfd'): + face_detector = 'sfd', + face_enhancement = False): self.img_size = 96 self.checkpoint_path = checkpoint_path self.static = static @@ -42,6 +43,10 @@ class Wav2LipPredictor(BasePredictor): self.rotate = rotate self.nosmooth = nosmooth self.face_detector = face_detector + self.face_enhancement = face_enhancement + if face_enhancement: + from ppgan.faceutils.face_enhancement import FaceEnhancement + self.faceenhancer = FaceEnhancement() makedirs('./temp', exist_ok=True) def get_smoothened_boxes(self, boxes, T): @@ -271,6 +276,8 @@ class Wav2LipPredictor(BasePredictor): for p, f, c in zip(pred, frames, coords): y1, y2, x1, x2 = c + if self.face_enhancement: + p = self.faceenhancer.enhance_from_image(p) p = cv2.resize(p.astype(np.uint8), (x2 - x1, y2 - y1)) f[y1:y2, x1:x2] = p diff --git a/ppgan/faceutils/face_enhancement/__init__.py b/ppgan/faceutils/face_enhancement/__init__.py new file mode 100644 index 0000000..f429a82 --- /dev/null +++ b/ppgan/faceutils/face_enhancement/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from .face_enhance import FaceEnhancement diff --git a/ppgan/faceutils/face_enhancement/face_enhance.py b/ppgan/faceutils/face_enhancement/face_enhance.py new file mode 100644 index 0000000..055fc0b --- /dev/null +++ b/ppgan/faceutils/face_enhancement/face_enhance.py @@ -0,0 +1,78 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.nn as nn +import math +import cv2 +import numpy as np +from ppgan.utils.download import get_path_from_url +from ppgan.models.generators import GPEN +from ppgan.faceutils.face_detection.detection.blazeface.utils import * + +GPEN_weights = 'https://paddlegan.bj.bcebos.com/models/GPEN-512.pdparams' + + +class FaceEnhancement(object): + def __init__(self, + path_to_enhance=None, + size = 512, + batch_size=1 + ): + super(FaceEnhancement, self).__init__() + + # Initialise the face detector + if path_to_enhance is None: + model_weights_path = get_path_from_url(GPEN_weights) + model_weights = paddle.load(model_weights_path) + else: + model_weights = paddle.load(path_to_enhance) + + self.face_enhance = GPEN(size=512, style_dim=512, n_mlp=8) + self.face_enhance.load_dict(model_weights) + self.face_enhance.eval() + self.size = size + self.mask = np.zeros((512, 512), np.float32) + cv2.rectangle(self.mask, (26, 26), (486, 486), (1, 1, 1), -1, cv2.LINE_AA) + self.mask = cv2.GaussianBlur(self.mask, (101, 101), 11) + self.mask = cv2.GaussianBlur(self.mask, (101, 101), 11) + self.mask = paddle.tile(paddle.to_tensor(self.mask).unsqueeze(0).unsqueeze(-1), repeat_times=[batch_size,1,1,3]).numpy() + + + def enhance_from_image(self, img): + if isinstance(img, np.ndarray): + img, _ = resize_and_crop_image(img, 512) + img = paddle.to_tensor(img).transpose([2, 0, 1]) + else: + assert img.shape == [3, 512, 512] + return self.enhance_from_batch(img.unsqueeze(0))[0] + + def enhance_from_batch(self, img): + if isinstance(img, np.ndarray): + img_ori, _ = resize_and_crop_batch(img, 512) + img = paddle.to_tensor(img_ori).transpose([0, 3, 1, 2]) + else: + assert img.shape[1:] == [3, 512, 512] + img_ori = img.transpose([0, 2, 3, 1]).numpy() + img_t = (img/255. 
- 0.5) / 0.5 + + with paddle.no_grad(): + out, __ = self.face_enhance(img_t) + + image_tensor = out * 0.5 + 0.5 + image_tensor = image_tensor.transpose([0, 2, 3, 1]) # RGB + image_numpy = paddle.clip(image_tensor, 0, 1) * 255.0 + + out = image_numpy.astype(np.uint8).cpu().numpy() + return out * self.mask + (1-self.mask) * img_ori diff --git a/ppgan/models/generators/__init__.py b/ppgan/models/generators/__init__.py index 8df2ec1..18d7a43 100755 --- a/ppgan/models/generators/__init__.py +++ b/ppgan/models/generators/__init__.py @@ -32,3 +32,4 @@ from .generator_firstorder import FirstOrderGenerator from .generater_lapstyle import DecoderNet, Encoder, RevisionNet from .basicvsr import BasicVSRNet from .mpr import MPRNet +from .gpen import GPEN diff --git a/ppgan/models/generators/generator_styleganv2.py b/ppgan/models/generators/generator_styleganv2.py index cabfe34..72a6c0a 100644 --- a/ppgan/models/generators/generator_styleganv2.py +++ b/ppgan/models/generators/generator_styleganv2.py @@ -136,18 +136,21 @@ class ModulatedConv2D(nn.Layer): class NoiseInjection(nn.Layer): - def __init__(self): + def __init__(self, is_concat=False): super().__init__() self.weight = self.create_parameter( (1, ), default_initializer=nn.initializer.Constant(0.0)) + self.is_concat = is_concat def forward(self, image, noise=None): if noise is None: batch, _, height, width = image.shape noise = paddle.randn((batch, 1, height, width)) - - return image + self.weight * noise + if self.is_concat: + return paddle.concat([image, self.weight * noise], axis=1) + else: + return image + self.weight * noise class ConstantInput(nn.Layer): @@ -175,6 +178,7 @@ class StyledConv(nn.Layer): upsample=False, blur_kernel=[1, 3, 3, 1], demodulate=True, + is_concat=False ): super().__init__() @@ -188,8 +192,8 @@ class StyledConv(nn.Layer): demodulate=demodulate, ) - self.noise = NoiseInjection() - self.activate = FusedLeakyReLU(out_channel) + self.noise = NoiseInjection(is_concat=is_concat) + self.activate = FusedLeakyReLU(out_channel*2 if is_concat else out_channel) def forward(self, input, style, noise=None): out = self.conv(input, style) @@ -240,6 +244,7 @@ class StyleGANv2Generator(nn.Layer): channel_multiplier=2, blur_kernel=[1, 3, 3, 1], lr_mlp=0.01, + is_concat=False ): super().__init__() @@ -275,8 +280,9 @@ class StyleGANv2Generator(nn.Layer): self.channels[4], 3, style_dim, - blur_kernel=blur_kernel) - self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False) + blur_kernel=blur_kernel, + is_concat=is_concat) + self.to_rgb1 = ToRGB(self.channels[4]*2 if is_concat else self.channels[4], style_dim, upsample=False) self.log_size = int(math.log(size, 2)) self.num_layers = (self.log_size - 2) * 2 + 1 @@ -299,26 +305,29 @@ class StyleGANv2Generator(nn.Layer): self.convs.append( StyledConv( - in_channel, + in_channel*2 if is_concat else in_channel, out_channel, 3, style_dim, upsample=True, blur_kernel=blur_kernel, + is_concat=is_concat, )) self.convs.append( - StyledConv(out_channel, + StyledConv(out_channel*2 if is_concat else out_channel, out_channel, 3, style_dim, - blur_kernel=blur_kernel)) + blur_kernel=blur_kernel, + is_concat=is_concat)) - self.to_rgbs.append(ToRGB(out_channel, style_dim)) + self.to_rgbs.append(ToRGB(out_channel*2 if is_concat else out_channel, style_dim)) in_channel = out_channel self.n_latent = self.log_size * 2 - 2 + self.is_concat = is_concat def make_noise(self): noises = [paddle.randn((1, 1, 2**2, 2**2))] @@ -395,16 +404,29 @@ class StyleGANv2Generator(nn.Layer): skip = self.to_rgb1(out, latent[:, 
1]) i = 1 - for conv1, conv2, noise1, noise2, to_rgb in zip(self.convs[::2], - self.convs[1::2], - noise[1::2], - noise[2::2], - self.to_rgbs): - out = conv1(out, latent[:, i], noise=noise1) - out = conv2(out, latent[:, i + 1], noise=noise2) - skip = to_rgb(out, latent[:, i + 2], skip) - - i += 2 + if self.is_concat: + noise_i = 1 + + outs = [] + for conv1, conv2, to_rgb in zip( + self.convs[::2], self.convs[1::2], self.to_rgbs): + out = conv1(out, latent[:, i], noise=noise[(noise_i + 1)//2]) ### 1 for 2 + out = conv2(out, latent[:, i + 1], noise=noise[(noise_i + 2)//2]) ### 1 for 2 + skip = to_rgb(out, latent[:, i + 2], skip) + + i += 2 + noise_i += 2 + else: + for conv1, conv2, noise1, noise2, to_rgb in zip(self.convs[::2], + self.convs[1::2], + noise[1::2], + noise[2::2], + self.to_rgbs): + out = conv1(out, latent[:, i], noise=noise1) + out = conv2(out, latent[:, i + 1], noise=noise2) + skip = to_rgb(out, latent[:, i + 2], skip) + + i += 2 image = skip diff --git a/ppgan/models/generators/gpen.py b/ppgan/models/generators/gpen.py new file mode 100644 index 0000000..df72662 --- /dev/null +++ b/ppgan/models/generators/gpen.py @@ -0,0 +1,88 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# code was heavily based on https://github.com/yangxy/GPEN + +import paddle +import paddle.nn as nn +import math +from ppgan.models.generators import StyleGANv2Generator +from ppgan.models.discriminators.discriminator_styleganv2 import ConvLayer +from ppgan.modules.equalized import EqualLinear + +class GPEN(nn.Layer): + def __init__( + self, + size, + style_dim, + n_mlp, + channel_multiplier=2, + blur_kernel=[1, 3, 3, 1], + lr_mlp=0.01, + ): + super(GPEN, self).__init__() + channels = { + 4: 512, + 8: 512, + 16: 512, + 32: 512, + 64: 256 * channel_multiplier, + 128: 128 * channel_multiplier, + 256: 64 * channel_multiplier, + 512: 32 * channel_multiplier, + 1024: 16 * channel_multiplier, + } + + self.log_size = int(math.log(size, 2)) + self.generator = StyleGANv2Generator(size, + style_dim, + n_mlp, + channel_multiplier=channel_multiplier, + blur_kernel=blur_kernel, + lr_mlp=lr_mlp, + is_concat=True) + + conv = [ConvLayer(3, channels[size], 1)] + self.ecd0 = nn.Sequential(*conv) + in_channel = channels[size] + + self.names = ['ecd%d'%i for i in range(self.log_size-1)] + for i in range(self.log_size, 2, -1): + out_channel = channels[2 ** (i - 1)] + conv = [ConvLayer(in_channel, out_channel, 3, downsample=True)] + setattr(self, self.names[self.log_size-i+1], nn.Sequential(*conv)) + in_channel = out_channel + self.final_linear = nn.Sequential(EqualLinear(channels[4] * 4 * 4, style_dim, activation='fused_lrelu')) + + def forward(self, + inputs, + return_latents=False, + inject_index=None, + truncation=1, + truncation_latent=None, + input_is_latent=False, + ): + noise = [] + for i in range(self.log_size-1): + ecd = getattr(self, self.names[i]) + inputs = ecd(inputs) + noise.append(inputs) + inputs = inputs.reshape([inputs.shape[0], -1]) + outs = self.final_linear(inputs) + outs = self.generator([outs], return_latents, inject_index, truncation, + truncation_latent, input_is_latent, + noise=noise[::-1]) + return outs + + diff --git a/ppgan/models/generators/occlusion_aware.py b/ppgan/models/generators/occlusion_aware.py index 0e01d9f..1ce8aa1 100644 --- a/ppgan/models/generators/occlusion_aware.py +++ b/ppgan/models/generators/occlusion_aware.py @@ -17,8 +17,10 @@ import paddle from paddle import nn import paddle.nn.functional as F -from ...modules.first_order import ResBlock2d, SameBlock2d, UpBlock2d, DownBlock2d +from ...modules.first_order import ResBlock2d, SameBlock2d, UpBlock2d, DownBlock2d, make_coordinate_grid from ...modules.dense_motion import DenseMotionNetwork +import numpy as np +import cv2 class OcclusionAwareGenerator(nn.Layer): @@ -35,7 +37,8 @@ class OcclusionAwareGenerator(nn.Layer): num_bottleneck_blocks, estimate_occlusion_map=False, dense_motion_params=None, - estimate_jacobian=False): + estimate_jacobian=False, + inference=False): super(OcclusionAwareGenerator, self).__init__() if dense_motion_params is not None: @@ -89,6 +92,8 @@ class OcclusionAwareGenerator(nn.Layer): padding=(3, 3)) self.estimate_occlusion_map = estimate_occlusion_map self.num_channels = num_channels + self.inference = inference + self.pad = 5 def deform_input(self, inp, deformation): _, h_old, w_old, _ = deformation.shape @@ -100,6 +105,16 @@ class OcclusionAwareGenerator(nn.Layer): mode='bilinear', align_corners=False) deformation = deformation.transpose([0, 2, 3, 1]) + if self.inference: + identity_grid = make_coordinate_grid((h, w), + type=inp.dtype) + identity_grid = identity_grid.reshape([1, h, w, 2]) + visualization_matrix = np.zeros((h,w)).astype("float32") + 
visualization_matrix[self.pad:h-self.pad, self.pad:w-self.pad] = 1.0 + gauss_kernel = paddle.to_tensor(cv2.GaussianBlur(visualization_matrix , (9, 9), 0.0, borderType=cv2.BORDER_ISOLATED)) + gauss_kernel = gauss_kernel.unsqueeze(0).unsqueeze(-1) + deformation = gauss_kernel * deformation + (1-gauss_kernel) * identity_grid + return F.grid_sample(inp, deformation, mode='bilinear', @@ -136,6 +151,12 @@ class OcclusionAwareGenerator(nn.Layer): size=out.shape[2:], mode='bilinear', align_corners=False) + if self.inference: + h,w = occlusion_map.shape[2:] + occlusion_map[:,:,0:self.pad,:] = 1.0 + occlusion_map[:,:,:,0:self.pad] = 1.0 + occlusion_map[:,:,h-self.pad:h,:] = 1.0 + occlusion_map[:,:,:,w-self.pad:w] = 1.0 out = out * occlusion_map output_dict["deformed"] = self.deform_input(source_image, -- GitLab
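A note on the `face_enhance.py` hunk above: the seamless paste-back of the enhanced face relies on a Gaussian-feathered rectangular mask. A standalone NumPy/OpenCV sketch of that blending step, with the function name and defaults chosen here for illustration but mirroring the constants in the patch (26-pixel margin, 101x101 Gaussian applied twice):

```python
import cv2
import numpy as np

def blend_enhanced(enhanced, original, pad=26, ksize=101, sigma=11):
    """Blend the enhanced face into the original crop with a feathered mask.

    enhanced / original: HxWx3 uint8 images of the same size.
    """
    h, w = original.shape[:2]
    mask = np.zeros((h, w), np.float32)
    # solid rectangle inside a `pad`-pixel border...
    cv2.rectangle(mask, (pad, pad), (w - pad, h - pad), 1.0, -1, cv2.LINE_AA)
    # ...feathered by blurring twice, as in the patch
    mask = cv2.GaussianBlur(mask, (ksize, ksize), sigma)
    mask = cv2.GaussianBlur(mask, (ksize, ksize), sigma)
    mask = mask[..., None]
    return (mask * enhanced + (1.0 - mask) * original).astype(np.uint8)
```

This is why `FaceEnhancement.enhance_from_batch` returns `out * self.mask + (1 - self.mask) * img_ori`: only the interior of the 512x512 crop is replaced, and the border fades back to the original pixels.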
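The `deform_input` change in `occlusion_aware.py` is the "fix edge problem" half of this patch: at inference time the predicted deformation grid is feathered back toward the identity grid near the crop border, so `grid_sample` no longer drags in pixels from outside the face crop. A self-contained sketch of that blending under the usual `[-1, 1]` sampling-grid convention (function and variable names are illustrative):

```python
import cv2
import numpy as np
import paddle

def soften_deformation(deformation, pad=5):
    """deformation: [N, H, W, 2] sampling grid with coordinates in [-1, 1]."""
    n, h, w, _ = deformation.shape
    # identity grid laid out as (x, y), matching grid_sample's convention
    ys, xs = np.meshgrid(np.linspace(-1, 1, h), np.linspace(-1, 1, w), indexing="ij")
    identity = paddle.to_tensor(
        np.stack([xs, ys], axis=-1).astype("float32")).reshape([1, h, w, 2])
    # weight is 1 in the interior, 0 on a `pad`-pixel border, feathered by a Gaussian
    weight = np.zeros((h, w), dtype="float32")
    weight[pad:h - pad, pad:w - pad] = 1.0
    weight = cv2.GaussianBlur(weight, (9, 9), 0.0, borderType=cv2.BORDER_ISOLATED)
    weight = paddle.to_tensor(weight).unsqueeze(0).unsqueeze(-1)
    # near the border the grid falls back to the identity mapping
    return weight * deformation + (1.0 - weight) * identity
```

The same hunk also clamps a `pad`-pixel border of the occlusion map to 1.0, so the very edge of the pasted region keeps the original background instead of being darkened by the occlusion estimate.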