diff --git a/applications/tools/pixel2style2pixel.py b/applications/tools/pixel2style2pixel.py new file mode 100644 index 0000000000000000000000000000000000000000..69a2452d57eb3e40571f1e594d9073773e2fa902 --- /dev/null +++ b/applications/tools/pixel2style2pixel.py @@ -0,0 +1,72 @@ +import paddle +import os +import sys +sys.path.insert(0, os.getcwd()) +from ppgan.apps import Pixel2Style2PixelPredictor +import argparse + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--input_image", type=str, help="path to source image") + + parser.add_argument("--output_path", + type=str, + default='output_dir', + help="path to output image dir") + + parser.add_argument("--weight_path", + type=str, + default=None, + help="path to model checkpoint path") + + parser.add_argument("--model_type", + type=str, + default=None, + help="type of model for loading pretrained model") + + parser.add_argument("--seed", + type=int, + default=None, + help="sample random seed for model's image generation") + + parser.add_argument("--size", + type=int, + default=1024, + help="resolution of output image") + + parser.add_argument("--style_dim", + type=int, + default=512, + help="number of style dimension") + + parser.add_argument("--n_mlp", + type=int, + default=8, + help="number of mlp layer depth") + + parser.add_argument("--channel_multiplier", + type=int, + default=2, + help="number of channel multiplier") + + parser.add_argument("--cpu", + dest="cpu", + action="store_true", + help="cpu mode.") + + args = parser.parse_args() + + if args.cpu: + paddle.set_device('cpu') + + predictor = Pixel2Style2PixelPredictor( + output_path=args.output_path, + weight_path=args.weight_path, + model_type=args.model_type, + seed=args.seed, + size=args.size, + style_dim=args.style_dim, + n_mlp=args.n_mlp, + channel_multiplier=args.channel_multiplier + ) + predictor.run(args.input_image) diff --git a/applications/tools/styleganv2.py b/applications/tools/styleganv2.py new file mode 100644 index 0000000000000000000000000000000000000000..55f792837c300ccd03ad785e111faa84ecf05818 --- /dev/null +++ b/applications/tools/styleganv2.py @@ -0,0 +1,80 @@ +import paddle +import os +import sys +sys.path.insert(0, os.getcwd()) +from ppgan.apps import StyleGANv2Predictor +import argparse + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--output_path", + type=str, + default='output_dir', + help="path to output image dir") + + parser.add_argument("--weight_path", + type=str, + default=None, + help="path to model checkpoint path") + + parser.add_argument("--model_type", + type=str, + default=None, + help="type of model for loading pretrained model") + + parser.add_argument("--seed", + type=int, + default=None, + help="sample random seed for model's image generation") + + parser.add_argument("--size", + type=int, + default=1024, + help="resolution of output image") + + parser.add_argument("--style_dim", + type=int, + default=512, + help="number of style dimension") + + parser.add_argument("--n_mlp", + type=int, + default=8, + help="number of mlp layer depth") + + parser.add_argument("--channel_multiplier", + type=int, + default=2, + help="number of channel multiplier") + + parser.add_argument("--n_row", + type=int, + default=3, + help="row number of output image grid") + + parser.add_argument("--n_col", + type=int, + default=5, + help="column number of output image grid") + + parser.add_argument("--cpu", + dest="cpu", + action="store_true", + help="cpu mode.") + + args = 
parser.parse_args() + + if args.cpu: + paddle.set_device('cpu') + + predictor = StyleGANv2Predictor( + output_path=args.output_path, + weight_path=args.weight_path, + model_type=args.model_type, + seed=args.seed, + size=args.size, + style_dim=args.style_dim, + n_mlp=args.n_mlp, + channel_multiplier=args.channel_multiplier + ) + predictor.run(args.n_row, args.n_col) diff --git a/docs/en_US/tutorials/pixel2style2pixel.md b/docs/en_US/tutorials/pixel2style2pixel.md new file mode 100644 index 0000000000000000000000000000000000000000..2f37f9eb4d06b74c08365685da94d6186684a6fb --- /dev/null +++ b/docs/en_US/tutorials/pixel2style2pixel.md @@ -0,0 +1,87 @@ +# Pixel2Style2Pixel + +## Pixel2Style2Pixel introduction + +The task of Pixel2Style2Pixel is image encoding. It mainly encodes an input image as the style vector of StyleGAN V2 and uses StyleGAN V2 as the decoder. + +
+ +
+
+Pixel2Style2Pixel uses a fairly large encoder model to map an input image into the style-vector space of StyleGAN V2, so that the reconstructed (decoded) image stays strongly correlated with the input.
+
+Its main functions are:
+
+- Converting images into latent codes
+- Face frontalization
+- Generating images from sketches or segmentation maps
+- Converting low-resolution images into high-resolution images
+
+At present, PaddleGAN provides only the portrait reconstruction and portrait cartoonization (toonify) models.
+
+## How to use
+
+### Generate
+
+Users can run the following command to generate results from a local image:
+
+```
+cd applications/
+python -u tools/pixel2style2pixel.py \
+       --input_image <YOUR INPUT IMAGE> \
+       --output_path <DIRECTORY TO STORE OUTPUT IMAGES> \
+       --weight_path <YOUR PRETRAINED MODEL PATH> \
+       --model_type ffhq-inversion \
+       --seed 233 \
+       --size 1024 \
+       --style_dim 512 \
+       --n_mlp 8 \
+       --channel_multiplier 2 \
+       --cpu
+```
+
+**params:**
+- input_image: the input image file path
+- output_path: the directory where the generated images are stored
+- weight_path: pretrained model path
+- model_type: built-in model type in PaddleGAN. If an existing model type is used, `weight_path` has no effect.
+  Currently available: `ffhq-inversion`, `ffhq-toonify`
+- seed: random number seed
+- size: model parameter, resolution of the output image
+- style_dim: model parameter, dimension of the style vector z
+- n_mlp: model parameter, depth of the multi-layer perceptron applied to style z
+- channel_multiplier: model parameter, channel multiplier, which affects model size and the quality of generated images
+- cpu: run inference on CPU; remove this flag to use GPU
+
+### Train (TODO)
+
+In the future, training scripts will be added so that users can train more types of Pixel2Style2Pixel image encoders.
+
+
+## Results
+
+Input portrait:
+
+<div align="center">
+ +
+ +Cropped portrait-Reconstructed portrait-Cartoonized portrait: + +
+ + + +
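+
+The results above can also be produced directly from Python instead of the command line. Below is a minimal sketch using the `Pixel2Style2PixelPredictor` app added in this patch; the input path `face.png` is only a placeholder:
+
+```python
+from ppgan.apps import Pixel2Style2PixelPredictor
+
+# 'ffhq-inversion' and 'ffhq-toonify' are the built-in model types;
+# their pretrained weights are downloaded automatically.
+predictor = Pixel2Style2PixelPredictor(
+    output_path='output_dir',    # src.png and dst.png are written here
+    model_type='ffhq-inversion',
+    seed=233)
+predictor.run('face.png')        # path to a local portrait photo
+```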
+ +## Reference + +``` +@article{richardson2020encoding, + title={Encoding in Style: a StyleGAN Encoder for Image-to-Image Translation}, + author={Richardson, Elad and Alaluf, Yuval and Patashnik, Or and Nitzan, Yotam and Azar, Yaniv and Shapiro, Stav and Cohen-Or, Daniel}, + journal={arXiv preprint arXiv:2008.00951}, + year={2020} +} + +``` diff --git a/docs/en_US/tutorials/styleganv2.md b/docs/en_US/tutorials/styleganv2.md new file mode 100644 index 0000000000000000000000000000000000000000..dc742c391418ddf486196c2ae9b5aa1440175bfd --- /dev/null +++ b/docs/en_US/tutorials/styleganv2.md @@ -0,0 +1,83 @@ +# StyleGAN V2 + +## StyleGAN V2 introduction + +The task of StyleGAN V2 is image generation. Given a vector of a specific length, generate the image corresponding to the vector. It is an upgraded version of StyleGAN, which solves the problem of artifacts generated by StyleGAN. + +
+ +
+
+StyleGAN V2 can mix style vectors at multiple levels; its core is adaptive style decoupling.
+
+Compared with StyleGAN, its main improvements are:
+
+- Significantly better image quality (better FID score, fewer artifacts)
+- A new training scheme that replaces progressive growing, with better details such as teeth and eyes
+- Improved style mixing
+- Smoother interpolation
+- Faster training
+
+## How to use
+
+### Generate
+
+Users can run the following command to generate images; change or remove the seed to obtain different results:
+
+```
+cd applications/
+python -u tools/styleganv2.py \
+       --output_path <DIRECTORY TO STORE OUTPUT IMAGES> \
+       --weight_path <YOUR PRETRAINED MODEL PATH> \
+       --model_type ffhq-config-f \
+       --seed 233 \
+       --size 1024 \
+       --style_dim 512 \
+       --n_mlp 8 \
+       --channel_multiplier 2 \
+       --n_row 3 \
+       --n_col 5 \
+       --cpu
+```
+
+**params:**
+- output_path: the directory where the generated images are stored
+- weight_path: pretrained model path
+- model_type: built-in model type in PaddleGAN. If an existing model type is used, `weight_path` has no effect.
+  Currently available: `ffhq-config-f`, `animeface-512`
+- seed: random number seed
+- size: model parameter, resolution of the output image
+- style_dim: model parameter, dimension of the style vector z
+- n_mlp: model parameter, depth of the multi-layer perceptron applied to style z
+- channel_multiplier: model parameter, channel multiplier, which affects model size and the quality of generated images
+- n_row: number of rows in the sampled image grid
+- n_col: number of columns in the sampled image grid
+- cpu: run inference on CPU; remove this flag to use GPU
+
+### Train (TODO)
+
+In the future, training scripts will be added so that users can train more types of StyleGAN V2 image generators.
+
+
+## Results
+
+Random Samples:
+
+![Samples](../../imgs/stylegan2-sample.png)
+
+Random Style Mixing:
+
+![Random Style Mixing](../../imgs/stylegan2-sample-mixing-0.png)
+
+
+## Reference
+
+```
+@inproceedings{Karras2019stylegan2,
+  title = {Analyzing and Improving the Image Quality of {StyleGAN}},
+  author = {Tero Karras and Samuli Laine and Miika Aittala and Janne Hellsten and Jaakko Lehtinen and Timo Aila},
+  booktitle = {Proc.
CVPR}, + year = {2020} +} + +``` diff --git a/docs/imgs/pSp-input-crop.png b/docs/imgs/pSp-input-crop.png new file mode 100644 index 0000000000000000000000000000000000000000..93173c080f689a57fef46df4299dcdb2f112167e Binary files /dev/null and b/docs/imgs/pSp-input-crop.png differ diff --git a/docs/imgs/pSp-input.jpg b/docs/imgs/pSp-input.jpg new file mode 100644 index 0000000000000000000000000000000000000000..10c4d68bea9c04fd30f1173c7f2930c6bde79c89 Binary files /dev/null and b/docs/imgs/pSp-input.jpg differ diff --git a/docs/imgs/pSp-inversion.png b/docs/imgs/pSp-inversion.png new file mode 100644 index 0000000000000000000000000000000000000000..60fdf1525e62051bd314d9ccdbbb4cecc6a618ee Binary files /dev/null and b/docs/imgs/pSp-inversion.png differ diff --git a/docs/imgs/pSp-teaser.jpg b/docs/imgs/pSp-teaser.jpg new file mode 100644 index 0000000000000000000000000000000000000000..277d3afc4086ce094c78e264cadd76df793209d9 Binary files /dev/null and b/docs/imgs/pSp-teaser.jpg differ diff --git a/docs/imgs/pSp-toonify.png b/docs/imgs/pSp-toonify.png new file mode 100644 index 0000000000000000000000000000000000000000..d2d3cd892133e8c73ecdc1dc5c2f3591e4e3a3aa Binary files /dev/null and b/docs/imgs/pSp-toonify.png differ diff --git a/docs/imgs/stylegan2-sample-mixing-0.png b/docs/imgs/stylegan2-sample-mixing-0.png new file mode 100644 index 0000000000000000000000000000000000000000..699d8a212d001ec2b02859b4a9ae4aa18d916cee Binary files /dev/null and b/docs/imgs/stylegan2-sample-mixing-0.png differ diff --git a/docs/imgs/stylegan2-sample.png b/docs/imgs/stylegan2-sample.png new file mode 100644 index 0000000000000000000000000000000000000000..cd620ec312cd75200ac85d1c38e1ce7e2c13fa54 Binary files /dev/null and b/docs/imgs/stylegan2-sample.png differ diff --git a/docs/imgs/stylegan2-teaser-1024x256.png b/docs/imgs/stylegan2-teaser-1024x256.png new file mode 100644 index 0000000000000000000000000000000000000000..bb16c5f5c8b615983b36b2446564e654cc7805c3 Binary files /dev/null and b/docs/imgs/stylegan2-teaser-1024x256.png differ diff --git a/docs/zh_CN/tutorials/pixel2style2pixel.md b/docs/zh_CN/tutorials/pixel2style2pixel.md new file mode 100644 index 0000000000000000000000000000000000000000..09319653d0146555b3c37ab454c8199704d9b8be --- /dev/null +++ b/docs/zh_CN/tutorials/pixel2style2pixel.md @@ -0,0 +1,87 @@ +# Pixel2Style2Pixel + +## Pixel2Style2Pixel 原理 + +Pixel2Style2Pixel 的任务是image encoding。它主要是将图像编码为StyleGAN V2的风格向量,将StyleGAN V2当作解码器。 + +
+ +
+
+Pixel2Style2Pixel使用相当大的模型对图像进行编码,将图像编码到StyleGAN V2的风格向量空间中,使编码前的图像和解码后的图像具有强关联性。
+
+它的主要功能有:
+
+- 将图像转成隐藏编码
+- 将人脸转正
+- 根据草图或者分割结果生成图像
+- 将低分辨率图像转成高清图像
+
+目前在PaddleGAN中实现了人像重建和人像卡通化的模型。
+
+## 使用方法
+
+### 生成
+
+用户使用如下命令进行生成,选择本地图像作为输入:
+
+```
+cd applications/
+python -u tools/pixel2style2pixel.py \
+       --input_image <替换为输入的图像路径> \
+       --output_path <替换为生成图片存放的文件夹> \
+       --weight_path <替换为你的预训练模型路径> \
+       --model_type ffhq-inversion \
+       --seed 233 \
+       --size 1024 \
+       --style_dim 512 \
+       --n_mlp 8 \
+       --channel_multiplier 2 \
+       --cpu
+```
+
+**参数说明:**
+- input_image: 输入的图像路径
+- output_path: 生成图片存放的文件夹
+- weight_path: 预训练模型路径
+- model_type: PaddleGAN内置模型类型,若输入PaddleGAN已存在的模型类型,`weight_path`将失效。
+  当前可用: `ffhq-inversion`, `ffhq-toonify`
+- seed: 随机数种子
+- size: 模型参数,输出图片的分辨率
+- style_dim: 模型参数,风格z的维度
+- n_mlp: 模型参数,风格z所输入的多层感知层的层数
+- channel_multiplier: 模型参数,通道乘积,影响模型大小和生成图片质量
+- cpu: 是否使用cpu推理,若不使用,请在命令中去除
+
+### 训练(TODO)
+
+未来还将添加训练脚本方便用户训练出更多类型的 Pixel2Style2Pixel 图像编码器。
+
+
+## 生成结果展示
+
+输入人像:
+
+<div align="center">
+ +
+ +裁剪人像-重建人像-卡通化人像: + +
+ + + +
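+
+上述结果也可以不通过命令行,直接使用 Python 接口生成。下面给出一个最小示例(输入路径 `face.png` 仅为示意):
+
+```python
+from ppgan.apps import Pixel2Style2PixelPredictor
+
+# 内置模型类型为 'ffhq-inversion' 与 'ffhq-toonify',对应的预训练权重会自动下载
+predictor = Pixel2Style2PixelPredictor(
+    output_path='output_dir',    # 结果 src.png / dst.png 写入该目录
+    model_type='ffhq-inversion',
+    seed=233)
+predictor.run('face.png')        # 本地人像照片路径
+```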
+ +## 参考文献 + +``` +@article{richardson2020encoding, + title={Encoding in Style: a StyleGAN Encoder for Image-to-Image Translation}, + author={Richardson, Elad and Alaluf, Yuval and Patashnik, Or and Nitzan, Yotam and Azar, Yaniv and Shapiro, Stav and Cohen-Or, Daniel}, + journal={arXiv preprint arXiv:2008.00951}, + year={2020} +} + +``` diff --git a/docs/zh_CN/tutorials/styleganv2.md b/docs/zh_CN/tutorials/styleganv2.md new file mode 100644 index 0000000000000000000000000000000000000000..7ebab5e1ff14af2fdca8769515b40736491a6029 --- /dev/null +++ b/docs/zh_CN/tutorials/styleganv2.md @@ -0,0 +1,83 @@ +# StyleGAN V2 + +## StyleGAN V2 原理 + +StyleGAN V2 的任务是image generation,给定特定长度的向量,生成该向量对应的图像,是StyleGAN的升级版,解决了StyleGAN生成的伪像等问题。 + +
+ +
+ +StyleGAN V2 可对多级风格向量进行混合。其内核是自适应的风格解耦。 + +相对于StyleGAN,其主要改进为: + +- 生成的图像质量明显更好(FID分数更高、artifacts减少) +- 提出替代渐进式训练的新方法,牙齿、眼睛等细节更完美 +- 改善了风格混合 +- 更平滑的插值 +- 训练速度更快 + +## 使用方法 + +### 生成 + +用户使用如下命令中进行生成,可通过替换seed的值或去掉seed生成不同的结果: + +``` +cd applications/ +python -u tools/styleganv2.py \ + --output_path <替换为生成图片存放的文件夹> \ + --weight_path <替换为你的预训练模型路径> \ + --model_type ffhq-config-f \ + --seed 233 \ + --size 1024 \ + --style_dim 512 \ + --n_mlp 8 \ + --channel_multiplier 2 \ + --n_row 3 \ + --n_col 5 \ + --cpu +``` + +**参数说明:** +- output_path: 生成图片存放的文件夹 +- weight_path: 预训练模型路径 +- model_type: PaddleGAN内置模型类型,若输入PaddleGAN已存在的模型类型,`weight_path`将失效。 + 当前可用: `ffhq-config-f`, `animeface-512` +- seed: 随机数种子 +- size: 模型参数,输出图片的分辨率 +- style_dim: 模型参数,风格z的维度 +- n_mlp: 模型参数,风格z所输入的多层感知层的层数 +- channel_multiplier: 模型参数,通道乘积,影响模型大小和生成图片质量 +- n_row: 采样的图片的行数 +- n_col: 采样的图片的列数 +- cpu: 是否使用cpu推理,若不使用,请在命令中去除 + +### 训练(TODO) + +未来还将添加训练脚本方便用户训练出更多类型的 StyleGAN V2 图像生成器。 + + +## 生成结果展示 + +随机采样结果: + +![随机采样结果](../../imgs/stylegan2-sample.png) + +随机风格插值结果: + +![随机风格插值结果](../../imgs/stylegan2-sample-mixing-0.png) + + +## 参考文献 + +``` +@inproceedings{Karras2019stylegan2, + title = {Analyzing and Improving the Image Quality of {StyleGAN}}, + author = {Tero Karras and Samuli Laine and Miika Aittala and Janne Hellsten and Jaakko Lehtinen and Timo Aila}, + booktitle = {Proc. CVPR}, + year = {2020} +} + +``` diff --git a/ppgan/apps/__init__.py b/ppgan/apps/__init__.py index 8704748b74ce59476db0886e9f2e691ef2698dc9..a0aaaf0dc4444574fad45f30768f95b3d7d57af7 100644 --- a/ppgan/apps/__init__.py +++ b/ppgan/apps/__init__.py @@ -21,3 +21,5 @@ from .first_order_predictor import FirstOrderPredictor from .face_parse_predictor import FaceParsePredictor from .animegan_predictor import AnimeGANPredictor from .midas_predictor import MiDaSPredictor +from .styleganv2_predictor import StyleGANv2Predictor +from .pixel2style2pixel_predictor import Pixel2Style2PixelPredictor diff --git a/ppgan/apps/pixel2style2pixel_predictor.py b/ppgan/apps/pixel2style2pixel_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..b3722a9111cf6860fe49771f7fc5b83319b7f4ff --- /dev/null +++ b/ppgan/apps/pixel2style2pixel_predictor.py @@ -0,0 +1,199 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import cv2 +import scipy +import random +import numpy as np +import paddle +import paddle.vision.transforms as T +import ppgan.faceutils as futils +from .base_predictor import BasePredictor +from ppgan.models.generators import Pixel2Style2Pixel +from ppgan.utils.download import get_path_from_url +from PIL import Image + + +model_cfgs = { + 'ffhq-inversion': { + 'model_urls': 'https://paddlegan.bj.bcebos.com/models/pSp-ffhq-inversion.pdparams', + 'transform': T.Compose([ + T.Resize((256, 256)), + T.Transpose(), + T.Normalize([127.5, 127.5, 127.5], [127.5, 127.5, 127.5]) + ]), + 'size': 1024, + 'style_dim': 512, + 'n_mlp': 8, + 'channel_multiplier': 2 + }, + 'ffhq-toonify': { + 'model_urls': 'https://paddlegan.bj.bcebos.com/models/pSp-ffhq-toonify.pdparams', + 'transform': T.Compose([ + T.Resize((256, 256)), + T.Transpose(), + T.Normalize([127.5, 127.5, 127.5], [127.5, 127.5, 127.5]) + ]), + 'size': 1024, + 'style_dim': 512, + 'n_mlp': 8, + 'channel_multiplier': 2 + }, + 'default': { + 'transform': T.Compose([ + T.Resize((256, 256)), + T.Transpose(), + T.Normalize([127.5, 127.5, 127.5], [127.5, 127.5, 127.5]) + ]) + } +} + + +def run_alignment(image_path): + img = Image.open(image_path).convert("RGB") + face = futils.dlib.detect(img) + if not face: + raise Exception('Could not find a face in the given image.') + face_on_image = face[0] + lm = futils.dlib.landmarks(img, face_on_image) + lm = np.array(lm)[:,::-1] + lm_eye_left = lm[36 : 42] + lm_eye_right = lm[42 : 48] + lm_mouth_outer = lm[48 : 60] + + output_size = 1024 + transform_size = 4096 + enable_padding = True + + # Calculate auxiliary vectors. + eye_left = np.mean(lm_eye_left, axis=0) + eye_right = np.mean(lm_eye_right, axis=0) + eye_avg = (eye_left + eye_right) * 0.5 + eye_to_eye = eye_right - eye_left + mouth_left = lm_mouth_outer[0] + mouth_right = lm_mouth_outer[6] + mouth_avg = (mouth_left + mouth_right) * 0.5 + eye_to_mouth = mouth_avg - eye_avg + + # Choose oriented crop rectangle. + x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1] + x /= np.hypot(*x) + x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8) + y = np.flipud(x) * [-1, 1] + c = eye_avg + eye_to_mouth * 0.1 + quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y]) + qsize = np.hypot(*x) * 2 + + # Shrink. + shrink = int(np.floor(qsize / output_size * 0.5)) + if shrink > 1: + rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink))) + img = img.resize(rsize, Image.ANTIALIAS) + quad /= shrink + qsize /= shrink + + # Crop. + border = max(int(np.rint(qsize * 0.1)), 3) + crop = (int(np.floor(min(quad[:,0]))), int(np.floor(min(quad[:,1]))), int(np.ceil(max(quad[:,0]))), int(np.ceil(max(quad[:,1])))) + crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]), min(crop[3] + border, img.size[1])) + if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]: + img = img.crop(crop) + quad -= crop[0:2] + + # Pad. 
+ pad = (int(np.floor(min(quad[:,0]))), int(np.floor(min(quad[:,1]))), int(np.ceil(max(quad[:,0]))), int(np.ceil(max(quad[:,1])))) + pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0), max(pad[3] - img.size[1] + border, 0)) + if enable_padding and max(pad) > border - 4: + pad = np.maximum(pad, int(np.rint(qsize * 0.3))) + img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect') + h, w, _ = img.shape + y, x, _ = np.ogrid[:h, :w, :1] + mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w-1-x) / pad[2]), 1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h-1-y) / pad[3])) + blur = qsize * 0.02 + img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0) + img += (np.median(img, axis=(0,1)) - img) * np.clip(mask, 0.0, 1.0) + img = Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB') + quad += pad[:2] + + # Transform. + img = img.transform((transform_size, transform_size), Image.QUAD, (quad + 0.5).flatten(), Image.BILINEAR) + + return img + + +class AttrDict(dict): + def __init__(self, *args, **kwargs): + super(AttrDict, self).__init__(*args, **kwargs) + self.__dict__ = self + + +class Pixel2Style2PixelPredictor(BasePredictor): + def __init__(self, + output_path='output_dir', + weight_path=None, + model_type=None, + seed=None, + size=1024, + style_dim=512, + n_mlp=8, + channel_multiplier=2): + self.output_path = output_path + + if weight_path is None and model_type != 'default': + if model_type in model_cfgs.keys(): + weight_path = get_path_from_url(model_cfgs[model_type]['model_urls']) + size = model_cfgs[model_type].get('size', size) + style_dim = model_cfgs[model_type].get('style_dim', style_dim) + n_mlp = model_cfgs[model_type].get('n_mlp', n_mlp) + channel_multiplier = model_cfgs[model_type].get('channel_multiplier', channel_multiplier) + checkpoint = paddle.load(weight_path) + else: + raise ValueError('Predictor need a weight path or a pretrained model type') + else: + checkpoint = paddle.load(weight_path) + + opts = checkpoint.pop('opts') + opts = AttrDict(opts) + opts['size'] = size + opts['style_dim'] = style_dim + opts['n_mlp'] = n_mlp + opts['channel_multiplier'] = channel_multiplier + + self.generator = Pixel2Style2Pixel(opts) + self.generator.set_state_dict(checkpoint) + self.generator.eval() + + if seed is not None: + paddle.seed(seed) + random.seed(seed) + np.random.seed(seed) + + self.model_type = 'default' if model_type is None else model_type + + def run(self, image): + src_img = run_alignment(image) + src_img = np.asarray(src_img) + transformed_image = model_cfgs[self.model_type]['transform'](src_img) + dst_img = (self.generator(paddle.to_tensor(transformed_image[None, ...])) + * 0.5 + 0.5)[0].numpy() * 255 + dst_img = dst_img.transpose((1, 2, 0)) + + os.makedirs(self.output_path, exist_ok=True) + save_src_path = os.path.join(self.output_path, 'src.png') + cv2.imwrite(save_src_path, cv2.cvtColor(src_img, cv2.COLOR_RGB2BGR)) + save_dst_path = os.path.join(self.output_path, 'dst.png') + cv2.imwrite(save_dst_path, cv2.cvtColor(dst_img, cv2.COLOR_RGB2BGR)) + + return src_img diff --git a/ppgan/apps/styleganv2_predictor.py b/ppgan/apps/styleganv2_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..c9626967735d3ddf395c3f417bc0a92687f65339 --- /dev/null +++ b/ppgan/apps/styleganv2_predictor.py @@ -0,0 +1,148 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import random +import numpy as np +import paddle +from .base_predictor import BasePredictor +from ppgan.models.generators import StyleGANv2Generator +from ppgan.utils.download import get_path_from_url +from ppgan.utils.visual import make_grid, tensor2img, save_image + + +model_cfgs = { + 'ffhq-config-f': { + 'model_urls': 'https://paddlegan.bj.bcebos.com/models/stylegan2-ffhq-config-f.pdparams', + 'size': 1024, + 'style_dim': 512, + 'n_mlp': 8, + 'channel_multiplier': 2 + }, + 'animeface-512': { + 'model_urls': 'https://paddlegan.bj.bcebos.com/models/stylegan2-animeface-512.pdparams', + 'size': 512, + 'style_dim': 512, + 'n_mlp': 8, + 'channel_multiplier': 2 + } +} + + +@paddle.no_grad() +def get_mean_style(generator): + mean_style = None + + for i in range(10): + style = generator.mean_latent(1024) + + if mean_style is None: + mean_style = style + + else: + mean_style += style + + mean_style /= 10 + return mean_style + + +@paddle.no_grad() +def sample(generator, mean_style, n_sample): + image = generator( + [paddle.randn([n_sample, generator.style_dim])], + truncation=0.7, + truncation_latent=mean_style, + )[0] + + return image + + +@paddle.no_grad() +def style_mixing(generator, mean_style, n_source, n_target): + source_code = paddle.randn([n_source, generator.style_dim]) + target_code = paddle.randn([n_target, generator.style_dim]) + + resolution = 2 ** ((generator.n_latent + 2) // 2) + + images = [paddle.ones([1, 3, resolution, resolution]) * -1] + + source_image = generator( + [source_code], truncation_latent=mean_style, truncation=0.7 + )[0] + target_image = generator( + [target_code], truncation_latent=mean_style, truncation=0.7 + )[0] + + images.append(source_image) + + for i in range(n_target): + image = generator( + [target_code[i].unsqueeze(0).tile([n_source, 1]), source_code], + truncation_latent=mean_style, + truncation=0.7, + )[0] + images.append(target_image[i].unsqueeze(0)) + images.append(image) + + images = paddle.concat(images, 0) + + return images + + +class StyleGANv2Predictor(BasePredictor): + def __init__(self, + output_path='output_dir', + weight_path=None, + model_type=None, + seed=None, + size=1024, + style_dim=512, + n_mlp=8, + channel_multiplier=2): + self.output_path = output_path + + if weight_path is None: + if model_type in model_cfgs.keys(): + weight_path = get_path_from_url(model_cfgs[model_type]['model_urls']) + size = model_cfgs[model_type].get('size', size) + style_dim = model_cfgs[model_type].get('style_dim', style_dim) + n_mlp = model_cfgs[model_type].get('n_mlp', n_mlp) + channel_multiplier = model_cfgs[model_type].get('channel_multiplier', channel_multiplier) + checkpoint = paddle.load(weight_path) + else: + raise ValueError('Predictor need a weight path or a pretrained model type') + else: + checkpoint = paddle.load(weight_path) + + self.generator = StyleGANv2Generator(size, style_dim, n_mlp, channel_multiplier) + self.generator.set_state_dict(checkpoint) + self.generator.eval() + + if seed 
is not None: + paddle.seed(seed) + random.seed(seed) + np.random.seed(seed) + + def run(self, n_row=3, n_col=5): + os.makedirs(self.output_path, exist_ok=True) + mean_style = get_mean_style(self.generator) + + img = sample(self.generator, mean_style, n_row * n_col) + save_image(tensor2img(make_grid(img, nrow=n_col)), f'{self.output_path}/sample.png') + + for j in range(2): + img = style_mixing(self.generator, mean_style, n_col, n_row) + save_image(tensor2img(make_grid( + img, nrow=n_col + 1 + )), f'{self.output_path}/sample_mixing_{j}.png') diff --git a/ppgan/models/discriminators/__init__.py b/ppgan/models/discriminators/__init__.py index 41c23b5210ab737d6b31b0db2daec6d1636792b9..f7af297488ab8f97eaf1015d56019cd5e4abad03 100644 --- a/ppgan/models/discriminators/__init__.py +++ b/ppgan/models/discriminators/__init__.py @@ -17,3 +17,4 @@ from .nlayers import NLayerDiscriminator, NLayerDiscriminatorWithClassification from .discriminator_ugatit import UGATITDiscriminator from .dcdiscriminator import DCDiscriminator from .discriminator_animegan import AnimeDiscriminator +from .discriminator_styleganv2 import StyleGANv2Discriminator diff --git a/ppgan/models/discriminators/discriminator_styleganv2.py b/ppgan/models/discriminators/discriminator_styleganv2.py new file mode 100644 index 0000000000000000000000000000000000000000..a06e1f60927d02de8021343daf345a4bd78b66fa --- /dev/null +++ b/ppgan/models/discriminators/discriminator_styleganv2.py @@ -0,0 +1,151 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import math +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + +from .builder import DISCRIMINATORS +from ...modules.equalized import EqualLinear, EqualConv2D +from ...modules.fused_act import FusedLeakyReLU +from ...modules.upfirdn2d import Upfirdn2dBlur + + +class ConvLayer(nn.Sequential): + def __init__( + self, + in_channel, + out_channel, + kernel_size, + downsample=False, + blur_kernel=[1, 3, 3, 1], + bias=True, + activate=True, + ): + layers = [] + + if downsample: + factor = 2 + p = (len(blur_kernel) - factor) + (kernel_size - 1) + pad0 = (p + 1) // 2 + pad1 = p // 2 + + layers.append(Upfirdn2dBlur(blur_kernel, pad=(pad0, pad1))) + + stride = 2 + self.padding = 0 + + else: + stride = 1 + self.padding = kernel_size // 2 + + layers.append( + EqualConv2D( + in_channel, + out_channel, + kernel_size, + padding=self.padding, + stride=stride, + bias=bias and not activate, + ) + ) + + if activate: + layers.append(FusedLeakyReLU(out_channel, bias=bias)) + + super().__init__(*layers) + + +class ResBlock(nn.Layer): + def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]): + super().__init__() + + self.conv1 = ConvLayer(in_channel, in_channel, 3) + self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True) + + self.skip = ConvLayer( + in_channel, out_channel, 1, downsample=True, activate=False, bias=False + ) + + def forward(self, input): + out = self.conv1(input) + out = self.conv2(out) + + skip = self.skip(input) + out = (out + skip) / math.sqrt(2) + + return out + + +@DISCRIMINATORS.register() +class StyleGANv2Discriminator(nn.Layer): + def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]): + super().__init__() + + channels = { + 4: 512, + 8: 512, + 16: 512, + 32: 512, + 64: 256 * channel_multiplier, + 128: 128 * channel_multiplier, + 256: 64 * channel_multiplier, + 512: 32 * channel_multiplier, + 1024: 16 * channel_multiplier, + } + + convs = [ConvLayer(3, channels[size], 1)] + + log_size = int(math.log(size, 2)) + + in_channel = channels[size] + + for i in range(log_size, 2, -1): + out_channel = channels[2 ** (i - 1)] + + convs.append(ResBlock(in_channel, out_channel, blur_kernel)) + + in_channel = out_channel + + self.convs = nn.Sequential(*convs) + + self.stddev_group = 4 + self.stddev_feat = 1 + + self.final_conv = ConvLayer(in_channel + 1, channels[4], 3) + self.final_linear = nn.Sequential( + EqualLinear(channels[4] * 4 * 4, channels[4], activation="fused_lrelu"), + EqualLinear(channels[4], 1), + ) + + def forward(self, input): + out = self.convs(input) + + batch, channel, height, width = out.shape + group = min(batch, self.stddev_group) + stddev = out.reshape(( + group, -1, self.stddev_feat, channel // self.stddev_feat, height, width + )) + stddev = paddle.sqrt(stddev.var(0, unbiased=False) + 1e-8) + stddev = stddev.mean([2, 3, 4], keepdim=True).squeeze(2) + stddev = stddev.tile((group, 1, height, width)) + out = paddle.concat([out, stddev], 1) + + out = self.final_conv(out) + + out = out.reshape((batch, -1)) + out = self.final_linear(out) + + return out diff --git a/ppgan/models/generators/__init__.py b/ppgan/models/generators/__init__.py index ad04c1cdf2d164eaa62b288a902aedf6594dcdb9..8c0feda68125f18d42e01f013aeb49795ab0a5e9 100644 --- a/ppgan/models/generators/__init__.py +++ b/ppgan/models/generators/__init__.py @@ -20,4 +20,6 @@ from .deep_conv import DeepConvGenerator, ConditionalDeepConvGenerator from .resnet_ugatit import ResnetUGATITGenerator from .dcgenerator import DCGenerator from .generater_animegan 
import AnimeGenerator, AnimeGeneratorLite -from .wav2lip import Wav2Lip \ No newline at end of file +from .wav2lip import Wav2Lip +from .generator_styleganv2 import StyleGANv2Generator +from .generator_pixel2style2pixel import Pixel2Style2Pixel diff --git a/ppgan/models/generators/generator_pixel2style2pixel.py b/ppgan/models/generators/generator_pixel2style2pixel.py new file mode 100644 index 0000000000000000000000000000000000000000..1651cc54c01b45df3a837d51e53a295f4a45b199 --- /dev/null +++ b/ppgan/models/generators/generator_pixel2style2pixel.py @@ -0,0 +1,384 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import numpy as np +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + +from collections import namedtuple + +from .builder import GENERATORS +from .generator_styleganv2 import StyleGANv2Generator +from ...modules.equalized import EqualLinear + + +class Flatten(nn.Layer): + def forward(self, input): + return input.reshape((input.shape[0], -1)) + + +def l2_norm(input, axis=1): + norm = paddle.norm(input, 2, axis, True) + output = paddle.div(input, norm) + return output + + +class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])): + """ A named tuple describing a ResNet block. """ + + +def get_block(in_channel, depth, num_units, stride=2): + return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)] + + +def get_blocks(num_layers): + if num_layers == 50: + blocks = [ + get_block(in_channel=64, depth=64, num_units=3), + get_block(in_channel=64, depth=128, num_units=4), + get_block(in_channel=128, depth=256, num_units=14), + get_block(in_channel=256, depth=512, num_units=3) + ] + elif num_layers == 100: + blocks = [ + get_block(in_channel=64, depth=64, num_units=3), + get_block(in_channel=64, depth=128, num_units=13), + get_block(in_channel=128, depth=256, num_units=30), + get_block(in_channel=256, depth=512, num_units=3) + ] + elif num_layers == 152: + blocks = [ + get_block(in_channel=64, depth=64, num_units=3), + get_block(in_channel=64, depth=128, num_units=8), + get_block(in_channel=128, depth=256, num_units=36), + get_block(in_channel=256, depth=512, num_units=3) + ] + else: + raise ValueError("Invalid number of layers: {}. 
Must be one of [50, 100, 152]".format(num_layers)) + return blocks + + +class SEModule(nn.Layer): + def __init__(self, channels, reduction): + super(SEModule, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2D(1) + self.fc1 = nn.Conv2D(channels, channels // reduction, kernel_size=1, padding=0, bias_attr=False) + self.relu = nn.ReLU() + self.fc2 = nn.Conv2D(channels // reduction, channels, kernel_size=1, padding=0, bias_attr=False) + self.sigmoid = nn.Sigmoid() + + def forward(self, x): + module_input = x + x = self.avg_pool(x) + x = self.fc1(x) + x = self.relu(x) + x = self.fc2(x) + x = self.sigmoid(x) + return module_input * x + + +class BottleneckIR(nn.Layer): + def __init__(self, in_channel, depth, stride): + super(BottleneckIR, self).__init__() + if in_channel == depth: + self.shortcut_layer = nn.MaxPool2D(1, stride) + else: + self.shortcut_layer = nn.Sequential( + nn.Conv2D(in_channel, depth, (1, 1), stride, bias_attr=False), + nn.BatchNorm2D(depth) + ) + self.res_layer = nn.Sequential( + nn.BatchNorm2D(in_channel), + nn.Conv2D(in_channel, depth, (3, 3), (1, 1), 1, bias_attr=False), nn.PReLU(depth), + nn.Conv2D(depth, depth, (3, 3), stride, 1, bias_attr=False), nn.BatchNorm2D(depth) + ) + + def forward(self, x): + shortcut = self.shortcut_layer(x) + res = self.res_layer(x) + return res + shortcut + + +class BottleneckIRSE(nn.Layer): + def __init__(self, in_channel, depth, stride): + super(BottleneckIRSE, self).__init__() + if in_channel == depth: + self.shortcut_layer = nn.MaxPool2D(1, stride) + else: + self.shortcut_layer = nn.Sequential( + nn.Conv2D(in_channel, depth, (1, 1), stride, bias_attr=False), + nn.BatchNorm2D(depth) + ) + self.res_layer = nn.Sequential( + nn.BatchNorm2D(in_channel), + nn.Conv2D(in_channel, depth, (3, 3), (1, 1), 1, bias_attr=False), + nn.PReLU(depth), + nn.Conv2D(depth, depth, (3, 3), stride, 1, bias_attr=False), + nn.BatchNorm2D(depth), + SEModule(depth, 16) + ) + + def forward(self, x): + shortcut = self.shortcut_layer(x) + res = self.res_layer(x) + return res + shortcut + + +class GradualStyleBlock(nn.Layer): + def __init__(self, in_c, out_c, spatial): + super(GradualStyleBlock, self).__init__() + self.out_c = out_c + self.spatial = spatial + num_pools = int(np.log2(spatial)) + modules = [] + modules += [nn.Conv2D(in_c, out_c, kernel_size=3, stride=2, padding=1), + nn.LeakyReLU()] + for i in range(num_pools - 1): + modules += [ + nn.Conv2D(out_c, out_c, kernel_size=3, stride=2, padding=1), + nn.LeakyReLU() + ] + self.convs = nn.Sequential(*modules) + self.linear = EqualLinear(out_c, out_c, lr_mul=1) + + def forward(self, x): + x = self.convs(x) + x = x.reshape((-1, self.out_c)) + x = self.linear(x) + return x + + +class GradualStyleEncoder(nn.Layer): + def __init__(self, num_layers, mode='ir', opts=None): + super(GradualStyleEncoder, self).__init__() + assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152' + assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se' + blocks = get_blocks(num_layers) + if mode == 'ir': + unit_module = BottleneckIR + elif mode == 'ir_se': + unit_module = BottleneckIRSE + self.input_layer = nn.Sequential(nn.Conv2D(opts.input_nc, 64, (3, 3), 1, 1, bias_attr=False), + nn.BatchNorm2D(64), + nn.PReLU(64)) + modules = [] + for block in blocks: + for bottleneck in block: + modules.append(unit_module(bottleneck.in_channel, + bottleneck.depth, + bottleneck.stride)) + self.body = nn.Sequential(*modules) + + self.styles = nn.LayerList() + self.style_count = 18 + self.coarse_ind = 3 + self.middle_ind = 7 + 
for i in range(self.style_count): + if i < self.coarse_ind: + style = GradualStyleBlock(512, 512, 16) + elif i < self.middle_ind: + style = GradualStyleBlock(512, 512, 32) + else: + style = GradualStyleBlock(512, 512, 64) + self.styles.append(style) + self.latlayer1 = nn.Conv2D(256, 512, kernel_size=1, stride=1, padding=0) + self.latlayer2 = nn.Conv2D(128, 512, kernel_size=1, stride=1, padding=0) + + def _upsample_add(self, x, y): + '''Upsample and add two feature maps. + Args: + x: (Tensor) top feature map to be upsampled. + y: (Tensor) lateral feature map. + Returns: + (Tensor) added feature map. + Note in Pypaddle, when input size is odd, the upsampled feature map + with `F.upsample(..., scale_factor=2, mode='nearest')` + maybe not equal to the lateral feature map size. + e.g. + original input size: [N,_,15,15] -> + conv2d feature map size: [N,_,8,8] -> + upsampled feature map size: [N,_,16,16] + So we choose bilinear upsample which supports arbitrary output sizes. + ''' + _, _, H, W = y.shape + return F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True) + y + + def forward(self, x): + x = self.input_layer(x) + + latents = [] + modulelist = list(self.body._sub_layers.values()) + for i, l in enumerate(modulelist): + x = l(x) + if i == 6: + c1 = x + elif i == 20: + c2 = x + elif i == 23: + c3 = x + + for j in range(self.coarse_ind): + latents.append(self.styles[j](c3)) + + p2 = self._upsample_add(c3, self.latlayer1(c2)) + for j in range(self.coarse_ind, self.middle_ind): + latents.append(self.styles[j](p2)) + + p1 = self._upsample_add(p2, self.latlayer2(c1)) + for j in range(self.middle_ind, self.style_count): + latents.append(self.styles[j](p1)) + + out = paddle.stack(latents, 1) + return out + + +class BackboneEncoderUsingLastLayerIntoW(nn.Layer): + def __init__(self, num_layers, mode='ir', opts=None): + super(BackboneEncoderUsingLastLayerIntoW, self).__init__() + print('Using BackboneEncoderUsingLastLayerIntoW') + assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152' + assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se' + blocks = get_blocks(num_layers) + if mode == 'ir': + unit_module = BottleneckIR + elif mode == 'ir_se': + unit_module = BottleneckIRSE + self.input_layer = nn.Sequential(nn.Conv2D(opts.input_nc, 64, (3, 3), 1, 1, bias_attr=False), + nn.BatchNorm2D(64), + nn.PReLU(64)) + self.output_pool = nn.AdaptiveAvgPool2D((1, 1)) + self.linear = EqualLinear(512, 512, lr_mul=1) + modules = [] + for block in blocks: + for bottleneck in block: + modules.append(unit_module(bottleneck.in_channel, + bottleneck.depth, + bottleneck.stride)) + self.body = nn.Sequential(*modules) + + def forward(self, x): + x = self.input_layer(x) + x = self.body(x) + x = self.output_pool(x) + x = x.reshape((-1, 512)) + x = self.linear(x) + return x + + +class BackboneEncoderUsingLastLayerIntoWPlus(nn.Layer): + def __init__(self, num_layers, mode='ir', opts=None): + super(BackboneEncoderUsingLastLayerIntoWPlus, self).__init__() + print('Using BackboneEncoderUsingLastLayerIntoWPlus') + assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152' + assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se' + blocks = get_blocks(num_layers) + if mode == 'ir': + unit_module = BottleneckIR + elif mode == 'ir_se': + unit_module = BottleneckIRSE + self.input_layer = nn.Sequential(nn.Conv2D(opts.input_nc, 64, (3, 3), 1, 1, bias_attr=False), + nn.BatchNorm2D(64), + nn.PReLU(64)) + self.output_layer_2 = nn.Sequential(nn.BatchNorm2D(512), + 
nn.AdaptiveAvgPool2D((7, 7)), + Flatten(), + nn.Linear(512 * 7 * 7, 512)) + self.linear = EqualLinear(512, 512 * 18, lr_mul=1) + modules = [] + for block in blocks: + for bottleneck in block: + modules.append(unit_module(bottleneck.in_channel, + bottleneck.depth, + bottleneck.stride)) + self.body = nn.Sequential(*modules) + + def forward(self, x): + x = self.input_layer(x) + x = self.body(x) + x = self.output_layer_2(x) + x = self.linear(x) + x = x.reshape((-1, 18, 512)) + return x + + +@GENERATORS.register() +class Pixel2Style2Pixel(nn.Layer): + def __init__(self, opts): + super(Pixel2Style2Pixel, self).__init__() + self.set_opts(opts) + # Define architecture + self.encoder = self.set_encoder() + self.decoder = StyleGANv2Generator(opts.size, opts.style_dim, opts.n_mlp, opts.channel_multiplier) + self.face_pool = nn.AdaptiveAvgPool2D((256, 256)) + self.style_dim = self.decoder.style_dim + self.n_latent = self.decoder.n_latent + if self.opts.start_from_latent_avg: + if self.opts.learn_in_w: + self.register_buffer('latent_avg', paddle.zeros([1, self.style_dim])) + else: + self.register_buffer('latent_avg', paddle.zeros([1, self.n_latent, self.style_dim])) + + def set_encoder(self): + if self.opts.encoder_type == 'GradualStyleEncoder': + encoder = GradualStyleEncoder(50, 'ir_se', self.opts) + elif self.opts.encoder_type == 'BackboneEncoderUsingLastLayerIntoW': + encoder = BackboneEncoderUsingLastLayerIntoW(50, 'ir_se', self.opts) + elif self.opts.encoder_type == 'BackboneEncoderUsingLastLayerIntoWPlus': + encoder = BackboneEncoderUsingLastLayerIntoWPlus(50, 'ir_se', self.opts) + else: + raise Exception('{} is not a valid encoders'.format(self.opts.encoder_type)) + return encoder + + def forward(self, x, resize=True, latent_mask=None, input_code=False, randomize_noise=True, + inject_latent=None, return_latents=False, alpha=None): + if input_code: + codes = x + else: + codes = self.encoder(x) + # normalize with respect to the center of an average face + if self.opts.start_from_latent_avg: + if self.opts.learn_in_w: + codes = codes + self.latent_avg.tile([codes.shape[0], 1]) + else: + codes = codes + self.latent_avg.tile([codes.shape[0], 1, 1]) + + + if latent_mask is not None: + for i in latent_mask: + if inject_latent is not None: + if alpha is not None: + codes[:, i] = alpha * inject_latent[:, i] + (1 - alpha) * codes[:, i] + else: + codes[:, i] = inject_latent[:, i] + else: + codes[:, i] = 0 + + input_is_latent = not input_code + images, result_latent = self.decoder([codes], + input_is_latent=input_is_latent, + randomize_noise=randomize_noise, + return_latents=return_latents) + + if resize: + images = self.face_pool(images) + + if return_latents: + return images, result_latent + else: + return images + + def set_opts(self, opts): + self.opts = opts diff --git a/ppgan/models/generators/generator_styleganv2.py b/ppgan/models/generators/generator_styleganv2.py new file mode 100644 index 0000000000000000000000000000000000000000..0c0ccbaaf4cfa969792523de6ff4439876e41c09 --- /dev/null +++ b/ppgan/models/generators/generator_styleganv2.py @@ -0,0 +1,395 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import random +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + +from .builder import GENERATORS +from ...modules.equalized import EqualLinear +from ...modules.fused_act import FusedLeakyReLU +from ...modules.upfirdn2d import Upfirdn2dUpsample, Upfirdn2dBlur + + +class PixelNorm(nn.Layer): + def __init__(self): + super().__init__() + + def forward(self, input): + return input * paddle.rsqrt(paddle.mean(input ** 2, 1, keepdim=True) + 1e-8) + + +class ModulatedConv2D(nn.Layer): + def __init__( + self, + in_channel, + out_channel, + kernel_size, + style_dim, + demodulate=True, + upsample=False, + downsample=False, + blur_kernel=[1, 3, 3, 1], + ): + super().__init__() + + self.eps = 1e-8 + self.kernel_size = kernel_size + self.in_channel = in_channel + self.out_channel = out_channel + self.upsample = upsample + self.downsample = downsample + + if upsample: + factor = 2 + p = (len(blur_kernel) - factor) - (kernel_size - 1) + pad0 = (p + 1) // 2 + factor - 1 + pad1 = p // 2 + 1 + + self.blur = Upfirdn2dBlur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor) + + if downsample: + factor = 2 + p = (len(blur_kernel) - factor) + (kernel_size - 1) + pad0 = (p + 1) // 2 + pad1 = p // 2 + + self.blur = Upfirdn2dBlur(blur_kernel, pad=(pad0, pad1)) + + fan_in = in_channel * kernel_size ** 2 + self.scale = 1 / math.sqrt(fan_in) + self.padding = kernel_size // 2 + + self.weight = self.create_parameter( + (1, out_channel, in_channel, kernel_size, kernel_size), default_initializer=nn.initializer.Normal() + ) + + self.modulation = EqualLinear(style_dim, in_channel, bias_init=1) + + self.demodulate = demodulate + + def __repr__(self): + return ( + f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, " + f"upsample={self.upsample}, downsample={self.downsample})" + ) + + def forward(self, input, style): + batch, in_channel, height, width = input.shape + + style = self.modulation(style).reshape((batch, 1, in_channel, 1, 1)) + weight = self.scale * self.weight * style + + if self.demodulate: + demod = paddle.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8) + weight = weight * demod.reshape((batch, self.out_channel, 1, 1, 1)) + + weight = weight.reshape(( + batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size + )) + + if self.upsample: + input = input.reshape((1, batch * in_channel, height, width)) + weight = weight.reshape(( + batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size + )) + weight = weight.transpose((0, 2, 1, 3, 4)).reshape(( + batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size + )) + out = F.conv2d_transpose(input, weight, padding=0, stride=2, groups=batch) + _, _, height, width = out.shape + out = out.reshape((batch, self.out_channel, height, width)) + out = self.blur(out) + + elif self.downsample: + input = self.blur(input) + _, _, height, width = input.shape + input = input.reshape((1, batch * in_channel, height, width)) + out = F.conv2d(input, weight, padding=0, stride=2, groups=batch) + _, _, height, width = out.shape + out = out.reshape((batch, 
self.out_channel, height, width)) + + else: + input = input.reshape((1, batch * in_channel, height, width)) + out = F.conv2d(input, weight, padding=self.padding, groups=batch) + _, _, height, width = out.shape + out = out.reshape((batch, self.out_channel, height, width)) + + return out + + +class NoiseInjection(nn.Layer): + def __init__(self): + super().__init__() + + self.weight = self.create_parameter((1,), default_initializer=nn.initializer.Constant(0.0)) + + def forward(self, image, noise=None): + if noise is None: + batch, _, height, width = image.shape + noise = paddle.randn((batch, 1, height, width)) + + return image + self.weight * noise + + +class ConstantInput(nn.Layer): + def __init__(self, channel, size=4): + super().__init__() + + self.input = self.create_parameter((1, channel, size, size), default_initializer=nn.initializer.Normal()) + + def forward(self, input): + batch = input.shape[0] + out = self.input.tile((batch, 1, 1, 1)) + + return out + + +class StyledConv(nn.Layer): + def __init__( + self, + in_channel, + out_channel, + kernel_size, + style_dim, + upsample=False, + blur_kernel=[1, 3, 3, 1], + demodulate=True, + ): + super().__init__() + + self.conv = ModulatedConv2D( + in_channel, + out_channel, + kernel_size, + style_dim, + upsample=upsample, + blur_kernel=blur_kernel, + demodulate=demodulate, + ) + + self.noise = NoiseInjection() + self.activate = FusedLeakyReLU(out_channel) + + def forward(self, input, style, noise=None): + out = self.conv(input, style) + out = self.noise(out, noise=noise) + out = self.activate(out) + + return out + + +class ToRGB(nn.Layer): + def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]): + super().__init__() + + if upsample: + self.upsample = Upfirdn2dUpsample(blur_kernel) + + self.conv = ModulatedConv2D(in_channel, 3, 1, style_dim, demodulate=False) + self.bias = self.create_parameter((1, 3, 1, 1), nn.initializer.Constant(0.0)) + + def forward(self, input, style, skip=None): + out = self.conv(input, style) + out = out + self.bias + + if skip is not None: + skip = self.upsample(skip) + + out = out + skip + + return out + + +@GENERATORS.register() +class StyleGANv2Generator(nn.Layer): + def __init__( + self, + size, + style_dim, + n_mlp, + channel_multiplier=2, + blur_kernel=[1, 3, 3, 1], + lr_mlp=0.01, + ): + super().__init__() + + self.size = size + + self.style_dim = style_dim + + layers = [PixelNorm()] + + for i in range(n_mlp): + layers.append( + EqualLinear( + style_dim, style_dim, lr_mul=lr_mlp, activation="fused_lrelu" + ) + ) + + self.style = nn.Sequential(*layers) + + self.channels = { + 4: 512, + 8: 512, + 16: 512, + 32: 512, + 64: 256 * channel_multiplier, + 128: 128 * channel_multiplier, + 256: 64 * channel_multiplier, + 512: 32 * channel_multiplier, + 1024: 16 * channel_multiplier, + } + + self.input = ConstantInput(self.channels[4]) + self.conv1 = StyledConv( + self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel + ) + self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False) + + self.log_size = int(math.log(size, 2)) + self.num_layers = (self.log_size - 2) * 2 + 1 + + self.convs = nn.LayerList() + self.upsamples = nn.LayerList() + self.to_rgbs = nn.LayerList() + self.noises = nn.Layer() + + in_channel = self.channels[4] + + for layer_idx in range(self.num_layers): + res = (layer_idx + 5) // 2 + shape = [1, 1, 2 ** res, 2 ** res] + self.noises.register_buffer(f"noise_{layer_idx}", paddle.randn(shape)) + + for i in range(3, self.log_size + 1): + out_channel = 
self.channels[2 ** i] + + self.convs.append( + StyledConv( + in_channel, + out_channel, + 3, + style_dim, + upsample=True, + blur_kernel=blur_kernel, + ) + ) + + self.convs.append( + StyledConv( + out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel + ) + ) + + self.to_rgbs.append(ToRGB(out_channel, style_dim)) + + in_channel = out_channel + + self.n_latent = self.log_size * 2 - 2 + + def make_noise(self): + noises = [paddle.randn((1, 1, 2 ** 2, 2 ** 2))] + + for i in range(3, self.log_size + 1): + for _ in range(2): + noises.append(paddle.randn((1, 1, 2 ** i, 2 ** i))) + + return noises + + def mean_latent(self, n_latent): + latent_in = paddle.randn(( + n_latent, self.style_dim + )) + latent = self.style(latent_in).mean(0, keepdim=True) + + return latent + + def get_latent(self, input): + return self.style(input) + + def forward( + self, + styles, + return_latents=False, + inject_index=None, + truncation=1, + truncation_latent=None, + input_is_latent=False, + noise=None, + randomize_noise=True, + ): + if not input_is_latent: + styles = [self.style(s) for s in styles] + + if noise is None: + if randomize_noise: + noise = [None] * self.num_layers + else: + noise = [ + getattr(self.noises, f"noise_{i}") for i in range(self.num_layers) + ] + + if truncation < 1: + style_t = [] + + for style in styles: + style_t.append( + truncation_latent + truncation * (style - truncation_latent) + ) + + styles = style_t + + if len(styles) < 2: + inject_index = self.n_latent + + if styles[0].ndim < 3: + latent = styles[0].unsqueeze(1).tile((1, inject_index, 1)) + + else: + latent = styles[0] + + else: + if inject_index is None: + inject_index = random.randint(1, self.n_latent - 1) + + latent = styles[0].unsqueeze(1).tile((1, inject_index, 1)) + latent2 = styles[1].unsqueeze(1).tile((1, self.n_latent - inject_index, 1)) + + latent = paddle.concat([latent, latent2], 1) + + out = self.input(latent) + out = self.conv1(out, latent[:, 0], noise=noise[0]) + + skip = self.to_rgb1(out, latent[:, 1]) + + i = 1 + for conv1, conv2, noise1, noise2, to_rgb in zip( + self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs + ): + out = conv1(out, latent[:, i], noise=noise1) + out = conv2(out, latent[:, i + 1], noise=noise2) + skip = to_rgb(out, latent[:, i + 2], skip) + + i += 2 + + image = skip + + if return_latents: + return image, latent + + else: + return image, None diff --git a/ppgan/modules/equalized.py b/ppgan/modules/equalized.py new file mode 100644 index 0000000000000000000000000000000000000000..7280ab0e212f7309f2125e19e83cca59b096f31e --- /dev/null +++ b/ppgan/modules/equalized.py @@ -0,0 +1,102 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + +from .fused_act import fused_leaky_relu + + +class EqualConv2D(nn.Layer): + """This convolutional layer class stabilizes the learning rate changes of its parameters. 
+ Equalizing learning rate keeps the weights in the network at a similar scale during training. + """ + def __init__( + self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True + ): + super().__init__() + + self.weight = self.create_parameter( + (out_channel, in_channel, kernel_size, kernel_size), default_initializer=nn.initializer.Normal() + ) + self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2) + + self.stride = stride + self.padding = padding + + if bias: + self.bias = self.create_parameter((out_channel,), nn.initializer.Constant(0.0)) + + else: + self.bias = None + + def forward(self, input): + out = F.conv2d( + input, + self.weight * self.scale, + bias=self.bias, + stride=self.stride, + padding=self.padding, + ) + + return out + + def __repr__(self): + return ( + f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]}," + f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})" + ) + + +class EqualLinear(nn.Layer): + """This linear layer class stabilizes the learning rate changes of its parameters. + Equalizing learning rate keeps the weights in the network at a similar scale during training. + """ + def __init__( + self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None + ): + super().__init__() + + self.weight = self.create_parameter((in_dim, out_dim), default_initializer=nn.initializer.Normal()) + self.weight[:] = (self.weight / lr_mul).detach() + + if bias: + self.bias = self.create_parameter((out_dim,), nn.initializer.Constant(bias_init)) + + else: + self.bias = None + + self.activation = activation + + self.scale = (1 / math.sqrt(in_dim)) * lr_mul + self.lr_mul = lr_mul + + def forward(self, input): + if self.activation: + out = F.linear(input, self.weight * self.scale) + out = fused_leaky_relu(out, self.bias * self.lr_mul) + + else: + out = F.linear( + input, self.weight * self.scale, bias=self.bias * self.lr_mul + ) + + return out + + def __repr__(self): + return ( + f"{self.__class__.__name__}({self.weight.shape[0]}, {self.weight.shape[1]})" + ) diff --git a/ppgan/modules/fused_act.py b/ppgan/modules/fused_act.py new file mode 100644 index 0000000000000000000000000000000000000000..8723af36c2799c5f3e82d6d4b2baccf70a347cce --- /dev/null +++ b/ppgan/modules/fused_act.py @@ -0,0 +1,48 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + + +class FusedLeakyReLU(nn.Layer): + def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5): + super().__init__() + + if bias: + self.bias = self.create_parameter((channel,), default_initializer=nn.initializer.Constant(0.0)) + + else: + self.bias = None + + self.negative_slope = negative_slope + self.scale = scale + + def forward(self, input): + return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) + + +def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5): + if bias is not None: + rest_dim = [1] * (input.ndim - bias.ndim - 1) + return ( + F.leaky_relu( + input + bias.reshape((1, bias.shape[0], *rest_dim)), negative_slope=0.2 + ) + * scale + ) + + else: + return F.leaky_relu(input, negative_slope=0.2) * scale diff --git a/ppgan/modules/upfirdn2d.py b/ppgan/modules/upfirdn2d.py new file mode 100644 index 0000000000000000000000000000000000000000..856378a62dd14c613787fd9ecead77d036b27467 --- /dev/null +++ b/ppgan/modules/upfirdn2d.py @@ -0,0 +1,143 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + + +def upfirdn2d_native( + input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 +): + _, channel, in_h, in_w = input.shape + input = input.reshape((-1, in_h, in_w, 1)) + + _, in_h, in_w, minor = input.shape + kernel_h, kernel_w = kernel.shape + + out = input.reshape((-1, in_h, 1, in_w, 1, minor)) + out = out.transpose((0,1,3,5,2,4)) + out = out.reshape((-1,1,1,1)) + out = F.pad(out, [0, up_x - 1, 0, up_y - 1]) + out = out.reshape((-1, in_h, in_w, minor, up_y, up_x)) + out = out.transpose((0,3,1,4,2,5)) + out = out.reshape((-1, minor, in_h * up_y, in_w * up_x)) + + out = F.pad( + out, [max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] + ) + out = out[ + :,:, + max(-pad_y0, 0) : out.shape[2] - max(-pad_y1, 0), + max(-pad_x0, 0) : out.shape[3] - max(-pad_x1, 0), + ] + + out = out.reshape(( + [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] + )) + w = paddle.flip(kernel, [0, 1]).reshape((1, 1, kernel_h, kernel_w)) + out = F.conv2d(out, w) + out = out.reshape(( + -1, + minor, + in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, + in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, + )) + out = out.transpose((0, 2, 3, 1)) + out = out[:, ::down_y, ::down_x, :] + + out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 + out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 + + return out.reshape((-1, channel, out_h, out_w)) + + +def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): + out = upfirdn2d_native( + input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1] + ) + + return out + + +def make_kernel(k): + k = paddle.to_tensor(k, dtype='float32') + + if k.ndim == 1: + k = k.unsqueeze(0) * k.unsqueeze(1) + + k /= k.sum() + + return k + + +class 
Upfirdn2dUpsample(nn.Layer): + def __init__(self, kernel, factor=2): + super().__init__() + + self.factor = factor + kernel = make_kernel(kernel) * (factor ** 2) + self.register_buffer("kernel", kernel) + + p = kernel.shape[0] - factor + + pad0 = (p + 1) // 2 + factor - 1 + pad1 = p // 2 + + self.pad = (pad0, pad1) + + def forward(self, input): + out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad) + + return out + + +class Upfirdn2dDownsample(nn.Layer): + def __init__(self, kernel, factor=2): + super().__init__() + + self.factor = factor + kernel = make_kernel(kernel) + self.register_buffer("kernel", kernel) + + p = kernel.shape[0] - factor + + pad0 = (p + 1) // 2 + pad1 = p // 2 + + self.pad = (pad0, pad1) + + def forward(self, input): + out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad) + + return out + + +class Upfirdn2dBlur(nn.Layer): + def __init__(self, kernel, pad, upsample_factor=1): + super().__init__() + + kernel = make_kernel(kernel) + + if upsample_factor > 1: + kernel = kernel * (upsample_factor ** 2) + + self.register_buffer("kernel", kernel) + + self.pad = pad + + def forward(self, input): + out = upfirdn2d(input, self.kernel, pad=self.pad) + + return out
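
As a quick reference, the following minimal sketch shows how the blur and upsample layers defined above behave. The input shape and the `[1, 3, 3, 1]` kernel mirror the defaults used elsewhere in this patch; the padding values are chosen here only so that the blur keeps the spatial size unchanged:

```python
import paddle

from ppgan.modules.upfirdn2d import Upfirdn2dBlur, Upfirdn2dUpsample

x = paddle.randn([1, 3, 64, 64])

# Separable [1, 3, 3, 1] low-pass filter; pad=(2, 1) preserves the spatial size.
blur = Upfirdn2dBlur([1, 3, 3, 1], pad=(2, 1))
print(blur(x).shape)   # [1, 3, 64, 64]

# 2x upsampling (zero insertion) followed by the same low-pass filter.
up = Upfirdn2dUpsample([1, 3, 3, 1])
print(up(x).shape)     # [1, 3, 128, 128]
```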