diff --git a/applications/tools/pixel2style2pixel.py b/applications/tools/pixel2style2pixel.py
new file mode 100644
index 0000000000000000000000000000000000000000..69a2452d57eb3e40571f1e594d9073773e2fa902
--- /dev/null
+++ b/applications/tools/pixel2style2pixel.py
@@ -0,0 +1,72 @@
+import paddle
+import os
+import sys
+sys.path.insert(0, os.getcwd())
+from ppgan.apps import Pixel2Style2PixelPredictor
+import argparse
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--input_image", type=str, help="path to source image")
+
+ parser.add_argument("--output_path",
+ type=str,
+ default='output_dir',
+ help="path to output image dir")
+
+ parser.add_argument("--weight_path",
+ type=str,
+ default=None,
+ help="path to model checkpoint path")
+
+ parser.add_argument("--model_type",
+ type=str,
+ default=None,
+ help="type of model for loading pretrained model")
+
+ parser.add_argument("--seed",
+ type=int,
+ default=None,
+ help="sample random seed for model's image generation")
+
+ parser.add_argument("--size",
+ type=int,
+ default=1024,
+ help="resolution of output image")
+
+ parser.add_argument("--style_dim",
+ type=int,
+ default=512,
+ help="number of style dimension")
+
+ parser.add_argument("--n_mlp",
+ type=int,
+ default=8,
+ help="number of mlp layer depth")
+
+ parser.add_argument("--channel_multiplier",
+ type=int,
+ default=2,
+ help="number of channel multiplier")
+
+ parser.add_argument("--cpu",
+ dest="cpu",
+ action="store_true",
+ help="cpu mode.")
+
+ args = parser.parse_args()
+
+ if args.cpu:
+ paddle.set_device('cpu')
+
+ predictor = Pixel2Style2PixelPredictor(
+ output_path=args.output_path,
+ weight_path=args.weight_path,
+ model_type=args.model_type,
+ seed=args.seed,
+ size=args.size,
+ style_dim=args.style_dim,
+ n_mlp=args.n_mlp,
+ channel_multiplier=args.channel_multiplier
+ )
+ predictor.run(args.input_image)
diff --git a/applications/tools/styleganv2.py b/applications/tools/styleganv2.py
new file mode 100644
index 0000000000000000000000000000000000000000..55f792837c300ccd03ad785e111faa84ecf05818
--- /dev/null
+++ b/applications/tools/styleganv2.py
@@ -0,0 +1,80 @@
+import paddle
+import os
+import sys
+sys.path.insert(0, os.getcwd())
+from ppgan.apps import StyleGANv2Predictor
+import argparse
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--output_path",
+ type=str,
+ default='output_dir',
+ help="path to output image dir")
+
+ parser.add_argument("--weight_path",
+ type=str,
+ default=None,
+ help="path to model checkpoint path")
+
+ parser.add_argument("--model_type",
+ type=str,
+ default=None,
+ help="type of model for loading pretrained model")
+
+ parser.add_argument("--seed",
+ type=int,
+ default=None,
+ help="sample random seed for model's image generation")
+
+ parser.add_argument("--size",
+ type=int,
+ default=1024,
+ help="resolution of output image")
+
+ parser.add_argument("--style_dim",
+ type=int,
+ default=512,
+ help="number of style dimension")
+
+ parser.add_argument("--n_mlp",
+ type=int,
+ default=8,
+ help="number of mlp layer depth")
+
+ parser.add_argument("--channel_multiplier",
+ type=int,
+ default=2,
+ help="number of channel multiplier")
+
+ parser.add_argument("--n_row",
+ type=int,
+ default=3,
+ help="row number of output image grid")
+
+ parser.add_argument("--n_col",
+ type=int,
+ default=5,
+ help="column number of output image grid")
+
+ parser.add_argument("--cpu",
+ dest="cpu",
+ action="store_true",
+ help="cpu mode.")
+
+ args = parser.parse_args()
+
+ if args.cpu:
+ paddle.set_device('cpu')
+
+ predictor = StyleGANv2Predictor(
+ output_path=args.output_path,
+ weight_path=args.weight_path,
+ model_type=args.model_type,
+ seed=args.seed,
+ size=args.size,
+ style_dim=args.style_dim,
+ n_mlp=args.n_mlp,
+ channel_multiplier=args.channel_multiplier
+ )
+ predictor.run(args.n_row, args.n_col)
diff --git a/docs/en_US/tutorials/pixel2style2pixel.md b/docs/en_US/tutorials/pixel2style2pixel.md
new file mode 100644
index 0000000000000000000000000000000000000000..2f37f9eb4d06b74c08365685da94d6186684a6fb
--- /dev/null
+++ b/docs/en_US/tutorials/pixel2style2pixel.md
@@ -0,0 +1,87 @@
+# Pixel2Style2Pixel
+
+## Pixel2Style2Pixel introduction
+
+Pixel2Style2Pixel (pSp) performs image encoding: it maps an input image to the style vectors of StyleGAN V2 and uses a fixed StyleGAN V2 generator as the decoder.
+
+![pSp pipeline](../../imgs/pSp-teaser.jpg)
+
+Pixel2Style2Pixel uses a fairly large encoder network to map an image into the style-vector space of StyleGAN V2, so that the input image and the decoded image remain strongly correlated.
+
+Its main functions are:
+
+- Convert an image to latent (style) codes
+- Frontalize faces (turn a face to a frontal view)
+- Generate images based on sketches or segmentation results
+- Convert low-resolution images to high-definition images
+
+At present, PaddleGAN provides only the portrait reconstruction (inversion) and portrait toonification models.
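+
+For readers who want to see how the encoder and the StyleGAN V2 decoder fit together, below is a minimal sketch that builds the `Pixel2Style2Pixel` model from `ppgan.models.generators` and runs a random 256x256 tensor through it. The `Opts` class is a hypothetical stand-in for the options stored inside the released checkpoints; its fields mirror what the model code expects (the size, style_dim, n_mlp and channel_multiplier defaults below, plus encoder options).
+
+```
+import paddle
+from ppgan.models.generators import Pixel2Style2Pixel
+
+# Hypothetical options object; the released checkpoints carry their own `opts`.
+class Opts:
+    encoder_type = 'GradualStyleEncoder'  # predicts one style code per generator layer
+    input_nc = 3                          # RGB input
+    size = 1024                           # StyleGAN V2 decoder output resolution
+    style_dim = 512
+    n_mlp = 8
+    channel_multiplier = 2
+    start_from_latent_avg = True          # offset predicted codes by the average latent
+    learn_in_w = False                    # predict W+ (one code per layer), not a single w
+
+model = Pixel2Style2Pixel(Opts())
+model.eval()
+
+x = paddle.randn([1, 3, 256, 256])        # the predictor feeds 256x256 crops to the encoder
+with paddle.no_grad():
+    out = model(x)                        # encode to style codes, decode with StyleGAN V2
+print(out.shape)                          # [1, 3, 256, 256] after the built-in face pooling
+```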
+
+## How to use
+
+### Generate
+
+Run the following command to encode a local image of your choice:
+
+```
+cd applications/
+python -u tools/pixel2style2pixel.py \
+ --input_image <path to input image> \
+ --output_path <path to output dir> \
+ --weight_path <path to pretrained model> \
+ --model_type ffhq-inversion \
+ --seed 233 \
+ --size 1024 \
+ --style_dim 512 \
+ --n_mlp 8 \
+ --channel_multiplier 2 \
+ --cpu
+```
+
+**params:**
+- input_image: the input image file path
+- output_path: the directory where the generated images are stored
+- weight_path: pretrained model path
+- model_type: built-in model type in PaddleGAN. If a built-in type is given, `weight_path` has no effect.
+  Currently available: `ffhq-inversion`, `ffhq-toonify`
+- seed: random seed
+- size: model parameter, resolution of the output image
+- style_dim: model parameter, dimension of the style vector z
+- n_mlp: model parameter, number of layers in the MLP mapping network for z
+- channel_multiplier: model parameter, channel multiplier; affects model size and the quality of generated images
+- cpu: run inference on CPU; remove this flag to use GPU
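+
+The same functionality is available from Python through the predictor class that the script above wraps. A minimal sketch (the image path and output directory are placeholders to replace with your own):
+
+```
+import paddle
+from ppgan.apps import Pixel2Style2PixelPredictor
+
+paddle.set_device('gpu')  # or 'cpu'
+predictor = Pixel2Style2PixelPredictor(
+    output_path='output_dir',      # placeholder output directory
+    model_type='ffhq-inversion',   # built-in type; pretrained weights are downloaded automatically
+    seed=233)
+predictor.run('face.jpg')          # placeholder input; writes src.png (aligned crop) and dst.png (result)
+```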
+
+### Train (TODO)
+
+Training scripts will be added in the future so that users can train more types of Pixel2Style2Pixel image encoders.
+
+
+## Results
+
+Input portrait:
+
+![Input portrait](../../imgs/pSp-input.jpg)
+
+Cropped portrait, reconstructed portrait, and toonified portrait:
+
+![Cropped](../../imgs/pSp-input-crop.png) ![Reconstructed](../../imgs/pSp-inversion.png) ![Toonified](../../imgs/pSp-toonify.png)
+
+## Reference
+
+```
+@article{richardson2020encoding,
+ title={Encoding in Style: a StyleGAN Encoder for Image-to-Image Translation},
+ author={Richardson, Elad and Alaluf, Yuval and Patashnik, Or and Nitzan, Yotam and Azar, Yaniv and Shapiro, Stav and Cohen-Or, Daniel},
+ journal={arXiv preprint arXiv:2008.00951},
+ year={2020}
+}
+
+```
diff --git a/docs/en_US/tutorials/styleganv2.md b/docs/en_US/tutorials/styleganv2.md
new file mode 100644
index 0000000000000000000000000000000000000000..dc742c391418ddf486196c2ae9b5aa1440175bfd
--- /dev/null
+++ b/docs/en_US/tutorials/styleganv2.md
@@ -0,0 +1,83 @@
+# StyleGAN V2
+
+## StyleGAN V2 introduction
+
+StyleGAN V2 performs image generation: given a latent vector of a fixed length, it generates the corresponding image. It is an upgraded version of StyleGAN that removes the characteristic artifacts of StyleGAN outputs.
+
+![StyleGAN V2 teaser](../../imgs/stylegan2-teaser-1024x256.png)
+
+StyleGAN V2 can mix style vectors at multiple levels; its core idea is adaptive style disentanglement.
+
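+As a small illustration of the multi-level styles, the sketch below drives the `StyleGANv2Generator` from `ppgan.models.generators` with randomly initialized weights (so it only shows the API, not real faces): one call samples an image with truncation, the other mixes two latents so that coarse layers follow the first latent and fine layers the second.
+
+```
+import paddle
+from ppgan.models.generators import StyleGANv2Generator
+
+gen = StyleGANv2Generator(size=1024, style_dim=512, n_mlp=8, channel_multiplier=2)
+gen.eval()
+
+with paddle.no_grad():
+    # Plain sampling: one z is mapped to w and reused for every layer.
+    img, _ = gen([paddle.randn([1, 512])], truncation=0.7,
+                 truncation_latent=gen.mean_latent(1024))
+
+    # Style mixing: layers before inject_index take z1's style, the rest take z2's.
+    z1, z2 = paddle.randn([1, 512]), paddle.randn([1, 512])
+    mixed, _ = gen([z1, z2], inject_index=4)
+
+print(img.shape, mixed.shape)  # both [1, 3, 1024, 1024]
+```
+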
+Compared with StyleGAN, its main improvements are:
+
+- Significantly better image quality (better FID scores, fewer artifacts)
+- A new alternative to progressive growing, with better fine details such as teeth and eyes
+- Improved style mixing
+- Smoother interpolation
+- Faster training
+
+## How to use
+
+### Generate
+
+Different results can be generated by changing the value of `--seed` or removing it. Use the following command to generate images:
+
+```
+cd applications/
+python -u tools/styleganv2.py \
+ --output_path <path to output dir> \
+ --weight_path <path to pretrained model> \
+ --model_type ffhq-config-f \
+ --seed 233 \
+ --size 1024 \
+ --style_dim 512 \
+ --n_mlp 8 \
+ --channel_multiplier 2 \
+ --n_row 3 \
+ --n_col 5 \
+ --cpu
+```
+
+**params:**
+- output_path: the directory where the generated images are stored
+- weight_path: pretrained model path
+- model_type: built-in model type in PaddleGAN. If a built-in type is given, `weight_path` has no effect.
+  Currently available: `ffhq-config-f`, `animeface-512`
+- seed: random seed
+- size: model parameter, resolution of the output image
+- style_dim: model parameter, dimension of the style vector z
+- n_mlp: model parameter, number of layers in the MLP mapping network for z
+- channel_multiplier: model parameter, channel multiplier; affects model size and the quality of generated images
+- n_row: number of rows in the sampled image grid
+- n_col: number of columns in the sampled image grid
+- cpu: run inference on CPU; remove this flag to use GPU
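+
+The predictor wrapped by the script can also be called directly from Python. A minimal sketch (the output directory is a placeholder):
+
+```
+import paddle
+from ppgan.apps import StyleGANv2Predictor
+
+paddle.set_device('gpu')  # or 'cpu'
+predictor = StyleGANv2Predictor(
+    output_path='output_dir',     # placeholder output directory
+    model_type='ffhq-config-f',   # built-in type; pretrained weights are downloaded automatically
+    seed=233)
+predictor.run(n_row=3, n_col=5)   # writes sample.png and sample_mixing_{0,1}.png
+```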
+
+### Train (TODO)
+
+Training scripts will be added in the future so that users can train more types of StyleGAN V2 image generators.
+
+
+## Results
+
+Random Samples:
+
+![Samples](../../imgs/stylegan2-sample.png)
+
+Random Style Mixing:
+
+![Random Style Mixing](../../imgs/stylegan2-sample-mixing-0.png)
+
+
+## Reference
+
+```
+@inproceedings{Karras2019stylegan2,
+ title = {Analyzing and Improving the Image Quality of {StyleGAN}},
+ author = {Tero Karras and Samuli Laine and Miika Aittala and Janne Hellsten and Jaakko Lehtinen and Timo Aila},
+ booktitle = {Proc. CVPR},
+ year = {2020}
+}
+
+```
diff --git a/docs/imgs/pSp-input-crop.png b/docs/imgs/pSp-input-crop.png
new file mode 100644
index 0000000000000000000000000000000000000000..93173c080f689a57fef46df4299dcdb2f112167e
Binary files /dev/null and b/docs/imgs/pSp-input-crop.png differ
diff --git a/docs/imgs/pSp-input.jpg b/docs/imgs/pSp-input.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..10c4d68bea9c04fd30f1173c7f2930c6bde79c89
Binary files /dev/null and b/docs/imgs/pSp-input.jpg differ
diff --git a/docs/imgs/pSp-inversion.png b/docs/imgs/pSp-inversion.png
new file mode 100644
index 0000000000000000000000000000000000000000..60fdf1525e62051bd314d9ccdbbb4cecc6a618ee
Binary files /dev/null and b/docs/imgs/pSp-inversion.png differ
diff --git a/docs/imgs/pSp-teaser.jpg b/docs/imgs/pSp-teaser.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..277d3afc4086ce094c78e264cadd76df793209d9
Binary files /dev/null and b/docs/imgs/pSp-teaser.jpg differ
diff --git a/docs/imgs/pSp-toonify.png b/docs/imgs/pSp-toonify.png
new file mode 100644
index 0000000000000000000000000000000000000000..d2d3cd892133e8c73ecdc1dc5c2f3591e4e3a3aa
Binary files /dev/null and b/docs/imgs/pSp-toonify.png differ
diff --git a/docs/imgs/stylegan2-sample-mixing-0.png b/docs/imgs/stylegan2-sample-mixing-0.png
new file mode 100644
index 0000000000000000000000000000000000000000..699d8a212d001ec2b02859b4a9ae4aa18d916cee
Binary files /dev/null and b/docs/imgs/stylegan2-sample-mixing-0.png differ
diff --git a/docs/imgs/stylegan2-sample.png b/docs/imgs/stylegan2-sample.png
new file mode 100644
index 0000000000000000000000000000000000000000..cd620ec312cd75200ac85d1c38e1ce7e2c13fa54
Binary files /dev/null and b/docs/imgs/stylegan2-sample.png differ
diff --git a/docs/imgs/stylegan2-teaser-1024x256.png b/docs/imgs/stylegan2-teaser-1024x256.png
new file mode 100644
index 0000000000000000000000000000000000000000..bb16c5f5c8b615983b36b2446564e654cc7805c3
Binary files /dev/null and b/docs/imgs/stylegan2-teaser-1024x256.png differ
diff --git a/docs/zh_CN/tutorials/pixel2style2pixel.md b/docs/zh_CN/tutorials/pixel2style2pixel.md
new file mode 100644
index 0000000000000000000000000000000000000000..09319653d0146555b3c37ab454c8199704d9b8be
--- /dev/null
+++ b/docs/zh_CN/tutorials/pixel2style2pixel.md
@@ -0,0 +1,87 @@
+# Pixel2Style2Pixel
+
+## Pixel2Style2Pixel principle
+
+Pixel2Style2Pixel (pSp) performs image encoding: it maps an input image to the style vectors of StyleGAN V2 and uses StyleGAN V2 as the decoder.
+
+![pSp pipeline](../../imgs/pSp-teaser.jpg)
+
+Pixel2Style2Pixel uses a fairly large encoder network to map an image into the style-vector space of StyleGAN V2, so that the input image and the decoded image remain strongly correlated.
+
+Its main functions are:
+
+- Convert an image to latent (style) codes
+- Frontalize faces (turn a face to a frontal view)
+- Generate images from sketches or segmentation results
+- Convert low-resolution images to high-resolution images
+
+At present, PaddleGAN provides the portrait reconstruction (inversion) and portrait toonification models.
+
+## How to use
+
+### Generate
+
+Run the following command, choosing a local image as input:
+
+```
+cd applications/
+python -u tools/pixel2style2pixel.py \
+ --input_image <path to input image> \
+ --output_path <path to output dir> \
+ --weight_path <path to pretrained model> \
+ --model_type ffhq-inversion \
+ --seed 233 \
+ --size 1024 \
+ --style_dim 512 \
+ --n_mlp 8 \
+ --channel_multiplier 2 \
+ --cpu
+```
+
+**Parameters:**
+- input_image: path to the input image
+- output_path: directory where the generated images are stored
+- weight_path: path to the pretrained model
+- model_type: built-in model type in PaddleGAN. If a built-in type is given, `weight_path` has no effect.
+  Currently available: `ffhq-inversion`, `ffhq-toonify`
+- seed: random seed
+- size: model parameter, resolution of the output image
+- style_dim: model parameter, dimension of the style vector z
+- n_mlp: model parameter, number of layers in the MLP mapping network for z
+- channel_multiplier: model parameter, channel multiplier; affects model size and the quality of generated images
+- cpu: run inference on CPU; remove this flag to use GPU
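+
+The predictor behind the script can also be used from Python. A minimal sketch (paths are placeholders), forcing CPU inference the same way the `--cpu` flag does:
+
+```
+import paddle
+from ppgan.apps import Pixel2Style2PixelPredictor
+
+paddle.set_device('cpu')           # equivalent to passing --cpu to the script
+predictor = Pixel2Style2PixelPredictor(
+    output_path='output_dir',      # placeholder output directory
+    model_type='ffhq-toonify',     # built-in toonification model
+    seed=233)
+predictor.run('face.jpg')          # placeholder input; writes src.png and dst.png
+```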
+
+### Train (TODO)
+
+Training scripts will be added in the future so that users can train more types of Pixel2Style2Pixel image encoders.
+
+
+## Results
+
+Input portrait:
+
+![Input portrait](../../imgs/pSp-input.jpg)
+
+Cropped portrait, reconstructed portrait, and toonified portrait:
+
+![Cropped](../../imgs/pSp-input-crop.png) ![Reconstructed](../../imgs/pSp-inversion.png) ![Toonified](../../imgs/pSp-toonify.png)
+
+## Reference
+
+```
+@article{richardson2020encoding,
+ title={Encoding in Style: a StyleGAN Encoder for Image-to-Image Translation},
+ author={Richardson, Elad and Alaluf, Yuval and Patashnik, Or and Nitzan, Yotam and Azar, Yaniv and Shapiro, Stav and Cohen-Or, Daniel},
+ journal={arXiv preprint arXiv:2008.00951},
+ year={2020}
+}
+
+```
diff --git a/docs/zh_CN/tutorials/styleganv2.md b/docs/zh_CN/tutorials/styleganv2.md
new file mode 100644
index 0000000000000000000000000000000000000000..7ebab5e1ff14af2fdca8769515b40736491a6029
--- /dev/null
+++ b/docs/zh_CN/tutorials/styleganv2.md
@@ -0,0 +1,83 @@
+# StyleGAN V2
+
+## StyleGAN V2 principle
+
+StyleGAN V2 performs image generation: given a latent vector of a fixed length, it generates the corresponding image. It is an upgraded version of StyleGAN that removes the characteristic artifacts of StyleGAN outputs.
+
+![StyleGAN V2 teaser](../../imgs/stylegan2-teaser-1024x256.png)
+
+StyleGAN V2 can mix style vectors at multiple levels; its core idea is adaptive style disentanglement.
+
+Compared with StyleGAN, its main improvements are:
+
+- Significantly better image quality (better FID scores, fewer artifacts)
+- A new alternative to progressive growing, with better fine details such as teeth and eyes
+- Improved style mixing
+- Smoother interpolation
+- Faster training
+
+## How to use
+
+### Generate
+
+Different results can be generated by changing the value of `--seed` or removing it. Use the following command to generate images:
+
+```
+cd applications/
+python -u tools/styleganv2.py \
+ --output_path <path to output dir> \
+ --weight_path <path to pretrained model> \
+ --model_type ffhq-config-f \
+ --seed 233 \
+ --size 1024 \
+ --style_dim 512 \
+ --n_mlp 8 \
+ --channel_multiplier 2 \
+ --n_row 3 \
+ --n_col 5 \
+ --cpu
+```
+
+**Parameters:**
+- output_path: directory where the generated images are stored
+- weight_path: path to the pretrained model
+- model_type: built-in model type in PaddleGAN. If a built-in type is given, `weight_path` has no effect.
+  Currently available: `ffhq-config-f`, `animeface-512`
+- seed: random seed
+- size: model parameter, resolution of the output image
+- style_dim: model parameter, dimension of the style vector z
+- n_mlp: model parameter, number of layers in the MLP mapping network for z
+- channel_multiplier: model parameter, channel multiplier; affects model size and the quality of generated images
+- n_row: number of rows in the sampled image grid
+- n_col: number of columns in the sampled image grid
+- cpu: run inference on CPU; remove this flag to use GPU
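+
+The predictor can also be driven from Python. A minimal sketch (the output directory is a placeholder); fixing the seed makes the samples reproducible:
+
+```
+import paddle
+from ppgan.apps import StyleGANv2Predictor
+
+predictor = StyleGANv2Predictor(
+    output_path='output_dir',    # placeholder output directory
+    model_type='animeface-512',  # built-in 512x512 anime-face model
+    seed=233)                    # fixed seed for reproducible samples
+predictor.run(n_row=3, n_col=5)  # writes sample.png and sample_mixing_{0,1}.png
+```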
+
+### Train (TODO)
+
+Training scripts will be added in the future so that users can train more types of StyleGAN V2 image generators.
+
+
+## Results
+
+Random samples:
+
+![Random samples](../../imgs/stylegan2-sample.png)
+
+Random style mixing:
+
+![Random style mixing](../../imgs/stylegan2-sample-mixing-0.png)
+
+
+## Reference
+
+```
+@inproceedings{Karras2019stylegan2,
+ title = {Analyzing and Improving the Image Quality of {StyleGAN}},
+ author = {Tero Karras and Samuli Laine and Miika Aittala and Janne Hellsten and Jaakko Lehtinen and Timo Aila},
+ booktitle = {Proc. CVPR},
+ year = {2020}
+}
+
+```
diff --git a/ppgan/apps/__init__.py b/ppgan/apps/__init__.py
index 8704748b74ce59476db0886e9f2e691ef2698dc9..a0aaaf0dc4444574fad45f30768f95b3d7d57af7 100644
--- a/ppgan/apps/__init__.py
+++ b/ppgan/apps/__init__.py
@@ -21,3 +21,5 @@ from .first_order_predictor import FirstOrderPredictor
from .face_parse_predictor import FaceParsePredictor
from .animegan_predictor import AnimeGANPredictor
from .midas_predictor import MiDaSPredictor
+from .styleganv2_predictor import StyleGANv2Predictor
+from .pixel2style2pixel_predictor import Pixel2Style2PixelPredictor
diff --git a/ppgan/apps/pixel2style2pixel_predictor.py b/ppgan/apps/pixel2style2pixel_predictor.py
new file mode 100644
index 0000000000000000000000000000000000000000..b3722a9111cf6860fe49771f7fc5b83319b7f4ff
--- /dev/null
+++ b/ppgan/apps/pixel2style2pixel_predictor.py
@@ -0,0 +1,199 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import cv2
+import scipy
+import random
+import numpy as np
+import paddle
+import paddle.vision.transforms as T
+import ppgan.faceutils as futils
+from .base_predictor import BasePredictor
+from ppgan.models.generators import Pixel2Style2Pixel
+from ppgan.utils.download import get_path_from_url
+from PIL import Image
+
+
+model_cfgs = {
+ 'ffhq-inversion': {
+ 'model_urls': 'https://paddlegan.bj.bcebos.com/models/pSp-ffhq-inversion.pdparams',
+ 'transform': T.Compose([
+ T.Resize((256, 256)),
+ T.Transpose(),
+ T.Normalize([127.5, 127.5, 127.5], [127.5, 127.5, 127.5])
+ ]),
+ 'size': 1024,
+ 'style_dim': 512,
+ 'n_mlp': 8,
+ 'channel_multiplier': 2
+ },
+ 'ffhq-toonify': {
+ 'model_urls': 'https://paddlegan.bj.bcebos.com/models/pSp-ffhq-toonify.pdparams',
+ 'transform': T.Compose([
+ T.Resize((256, 256)),
+ T.Transpose(),
+ T.Normalize([127.5, 127.5, 127.5], [127.5, 127.5, 127.5])
+ ]),
+ 'size': 1024,
+ 'style_dim': 512,
+ 'n_mlp': 8,
+ 'channel_multiplier': 2
+ },
+ 'default': {
+ 'transform': T.Compose([
+ T.Resize((256, 256)),
+ T.Transpose(),
+ T.Normalize([127.5, 127.5, 127.5], [127.5, 127.5, 127.5])
+ ])
+ }
+}
+
+
+def run_alignment(image_path):
+ img = Image.open(image_path).convert("RGB")
+ face = futils.dlib.detect(img)
+ if not face:
+ raise Exception('Could not find a face in the given image.')
+ face_on_image = face[0]
+ lm = futils.dlib.landmarks(img, face_on_image)
+ lm = np.array(lm)[:,::-1]
+ lm_eye_left = lm[36 : 42]
+ lm_eye_right = lm[42 : 48]
+ lm_mouth_outer = lm[48 : 60]
+
+ output_size = 1024
+ transform_size = 4096
+ enable_padding = True
+
+ # Calculate auxiliary vectors.
+ eye_left = np.mean(lm_eye_left, axis=0)
+ eye_right = np.mean(lm_eye_right, axis=0)
+ eye_avg = (eye_left + eye_right) * 0.5
+ eye_to_eye = eye_right - eye_left
+ mouth_left = lm_mouth_outer[0]
+ mouth_right = lm_mouth_outer[6]
+ mouth_avg = (mouth_left + mouth_right) * 0.5
+ eye_to_mouth = mouth_avg - eye_avg
+
+ # Choose oriented crop rectangle.
+ x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
+ x /= np.hypot(*x)
+ x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)
+ y = np.flipud(x) * [-1, 1]
+ c = eye_avg + eye_to_mouth * 0.1
+ quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
+ qsize = np.hypot(*x) * 2
+
+ # Shrink.
+ shrink = int(np.floor(qsize / output_size * 0.5))
+ if shrink > 1:
+ rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink)))
+ img = img.resize(rsize, Image.ANTIALIAS)
+ quad /= shrink
+ qsize /= shrink
+
+ # Crop.
+ border = max(int(np.rint(qsize * 0.1)), 3)
+ crop = (int(np.floor(min(quad[:,0]))), int(np.floor(min(quad[:,1]))), int(np.ceil(max(quad[:,0]))), int(np.ceil(max(quad[:,1]))))
+ crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]), min(crop[3] + border, img.size[1]))
+ if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]:
+ img = img.crop(crop)
+ quad -= crop[0:2]
+
+ # Pad.
+ pad = (int(np.floor(min(quad[:,0]))), int(np.floor(min(quad[:,1]))), int(np.ceil(max(quad[:,0]))), int(np.ceil(max(quad[:,1]))))
+ pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0), max(pad[3] - img.size[1] + border, 0))
+ if enable_padding and max(pad) > border - 4:
+ pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
+ img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
+ h, w, _ = img.shape
+ y, x, _ = np.ogrid[:h, :w, :1]
+ mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w-1-x) / pad[2]), 1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h-1-y) / pad[3]))
+ blur = qsize * 0.02
+ img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
+ img += (np.median(img, axis=(0,1)) - img) * np.clip(mask, 0.0, 1.0)
+ img = Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB')
+ quad += pad[:2]
+
+ # Transform.
+ img = img.transform((transform_size, transform_size), Image.QUAD, (quad + 0.5).flatten(), Image.BILINEAR)
+
+ return img
+
+
+class AttrDict(dict):
+ def __init__(self, *args, **kwargs):
+ super(AttrDict, self).__init__(*args, **kwargs)
+ self.__dict__ = self
+
+
+class Pixel2Style2PixelPredictor(BasePredictor):
+ def __init__(self,
+ output_path='output_dir',
+ weight_path=None,
+ model_type=None,
+ seed=None,
+ size=1024,
+ style_dim=512,
+ n_mlp=8,
+ channel_multiplier=2):
+ self.output_path = output_path
+
+ if weight_path is None and model_type != 'default':
+ if model_type in model_cfgs.keys():
+ weight_path = get_path_from_url(model_cfgs[model_type]['model_urls'])
+ size = model_cfgs[model_type].get('size', size)
+ style_dim = model_cfgs[model_type].get('style_dim', style_dim)
+ n_mlp = model_cfgs[model_type].get('n_mlp', n_mlp)
+ channel_multiplier = model_cfgs[model_type].get('channel_multiplier', channel_multiplier)
+ checkpoint = paddle.load(weight_path)
+ else:
+ raise ValueError('Predictor needs a weight path or a pretrained model type')
+ else:
+ checkpoint = paddle.load(weight_path)
+
+ opts = checkpoint.pop('opts')
+ opts = AttrDict(opts)
+ opts['size'] = size
+ opts['style_dim'] = style_dim
+ opts['n_mlp'] = n_mlp
+ opts['channel_multiplier'] = channel_multiplier
+
+ self.generator = Pixel2Style2Pixel(opts)
+ self.generator.set_state_dict(checkpoint)
+ self.generator.eval()
+
+ if seed is not None:
+ paddle.seed(seed)
+ random.seed(seed)
+ np.random.seed(seed)
+
+ self.model_type = 'default' if model_type is None else model_type
+
+ def run(self, image):
+ src_img = run_alignment(image)
+ src_img = np.asarray(src_img)
+ transformed_image = model_cfgs[self.model_type]['transform'](src_img)
+ dst_img = (self.generator(paddle.to_tensor(transformed_image[None, ...]))
+ * 0.5 + 0.5)[0].numpy() * 255
+ dst_img = dst_img.transpose((1, 2, 0))
+
+ os.makedirs(self.output_path, exist_ok=True)
+ save_src_path = os.path.join(self.output_path, 'src.png')
+ cv2.imwrite(save_src_path, cv2.cvtColor(src_img, cv2.COLOR_RGB2BGR))
+ save_dst_path = os.path.join(self.output_path, 'dst.png')
+ cv2.imwrite(save_dst_path, cv2.cvtColor(dst_img, cv2.COLOR_RGB2BGR))
+
+ return src_img
diff --git a/ppgan/apps/styleganv2_predictor.py b/ppgan/apps/styleganv2_predictor.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9626967735d3ddf395c3f417bc0a92687f65339
--- /dev/null
+++ b/ppgan/apps/styleganv2_predictor.py
@@ -0,0 +1,148 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import random
+import numpy as np
+import paddle
+from .base_predictor import BasePredictor
+from ppgan.models.generators import StyleGANv2Generator
+from ppgan.utils.download import get_path_from_url
+from ppgan.utils.visual import make_grid, tensor2img, save_image
+
+
+model_cfgs = {
+ 'ffhq-config-f': {
+ 'model_urls': 'https://paddlegan.bj.bcebos.com/models/stylegan2-ffhq-config-f.pdparams',
+ 'size': 1024,
+ 'style_dim': 512,
+ 'n_mlp': 8,
+ 'channel_multiplier': 2
+ },
+ 'animeface-512': {
+ 'model_urls': 'https://paddlegan.bj.bcebos.com/models/stylegan2-animeface-512.pdparams',
+ 'size': 512,
+ 'style_dim': 512,
+ 'n_mlp': 8,
+ 'channel_multiplier': 2
+ }
+}
+
+
+@paddle.no_grad()
+def get_mean_style(generator):
+ mean_style = None
+
+ for i in range(10):
+ style = generator.mean_latent(1024)
+
+ if mean_style is None:
+ mean_style = style
+
+ else:
+ mean_style += style
+
+ mean_style /= 10
+ return mean_style
+
+
+@paddle.no_grad()
+def sample(generator, mean_style, n_sample):
+ image = generator(
+ [paddle.randn([n_sample, generator.style_dim])],
+ truncation=0.7,
+ truncation_latent=mean_style,
+ )[0]
+
+ return image
+
+
+@paddle.no_grad()
+def style_mixing(generator, mean_style, n_source, n_target):
+ source_code = paddle.randn([n_source, generator.style_dim])
+ target_code = paddle.randn([n_target, generator.style_dim])
+
+ resolution = 2 ** ((generator.n_latent + 2) // 2)
+
+ images = [paddle.ones([1, 3, resolution, resolution]) * -1]
+
+ source_image = generator(
+ [source_code], truncation_latent=mean_style, truncation=0.7
+ )[0]
+ target_image = generator(
+ [target_code], truncation_latent=mean_style, truncation=0.7
+ )[0]
+
+ images.append(source_image)
+
+ for i in range(n_target):
+ image = generator(
+ [target_code[i].unsqueeze(0).tile([n_source, 1]), source_code],
+ truncation_latent=mean_style,
+ truncation=0.7,
+ )[0]
+ images.append(target_image[i].unsqueeze(0))
+ images.append(image)
+
+ images = paddle.concat(images, 0)
+
+ return images
+
+
+class StyleGANv2Predictor(BasePredictor):
+ def __init__(self,
+ output_path='output_dir',
+ weight_path=None,
+ model_type=None,
+ seed=None,
+ size=1024,
+ style_dim=512,
+ n_mlp=8,
+ channel_multiplier=2):
+ self.output_path = output_path
+
+ if weight_path is None:
+ if model_type in model_cfgs.keys():
+ weight_path = get_path_from_url(model_cfgs[model_type]['model_urls'])
+ size = model_cfgs[model_type].get('size', size)
+ style_dim = model_cfgs[model_type].get('style_dim', style_dim)
+ n_mlp = model_cfgs[model_type].get('n_mlp', n_mlp)
+ channel_multiplier = model_cfgs[model_type].get('channel_multiplier', channel_multiplier)
+ checkpoint = paddle.load(weight_path)
+ else:
+ raise ValueError('Predictor needs a weight path or a pretrained model type')
+ else:
+ checkpoint = paddle.load(weight_path)
+
+ self.generator = StyleGANv2Generator(size, style_dim, n_mlp, channel_multiplier)
+ self.generator.set_state_dict(checkpoint)
+ self.generator.eval()
+
+ if seed is not None:
+ paddle.seed(seed)
+ random.seed(seed)
+ np.random.seed(seed)
+
+ def run(self, n_row=3, n_col=5):
+ os.makedirs(self.output_path, exist_ok=True)
+ mean_style = get_mean_style(self.generator)
+
+ img = sample(self.generator, mean_style, n_row * n_col)
+ save_image(tensor2img(make_grid(img, nrow=n_col)), f'{self.output_path}/sample.png')
+
+ for j in range(2):
+ img = style_mixing(self.generator, mean_style, n_col, n_row)
+ save_image(tensor2img(make_grid(
+ img, nrow=n_col + 1
+ )), f'{self.output_path}/sample_mixing_{j}.png')
diff --git a/ppgan/models/discriminators/__init__.py b/ppgan/models/discriminators/__init__.py
index 41c23b5210ab737d6b31b0db2daec6d1636792b9..f7af297488ab8f97eaf1015d56019cd5e4abad03 100644
--- a/ppgan/models/discriminators/__init__.py
+++ b/ppgan/models/discriminators/__init__.py
@@ -17,3 +17,4 @@ from .nlayers import NLayerDiscriminator, NLayerDiscriminatorWithClassification
from .discriminator_ugatit import UGATITDiscriminator
from .dcdiscriminator import DCDiscriminator
from .discriminator_animegan import AnimeDiscriminator
+from .discriminator_styleganv2 import StyleGANv2Discriminator
diff --git a/ppgan/models/discriminators/discriminator_styleganv2.py b/ppgan/models/discriminators/discriminator_styleganv2.py
new file mode 100644
index 0000000000000000000000000000000000000000..a06e1f60927d02de8021343daf345a4bd78b66fa
--- /dev/null
+++ b/ppgan/models/discriminators/discriminator_styleganv2.py
@@ -0,0 +1,151 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+from .builder import DISCRIMINATORS
+from ...modules.equalized import EqualLinear, EqualConv2D
+from ...modules.fused_act import FusedLeakyReLU
+from ...modules.upfirdn2d import Upfirdn2dBlur
+
+
+class ConvLayer(nn.Sequential):
+ def __init__(
+ self,
+ in_channel,
+ out_channel,
+ kernel_size,
+ downsample=False,
+ blur_kernel=[1, 3, 3, 1],
+ bias=True,
+ activate=True,
+ ):
+ layers = []
+
+ if downsample:
+ factor = 2
+ p = (len(blur_kernel) - factor) + (kernel_size - 1)
+ pad0 = (p + 1) // 2
+ pad1 = p // 2
+
+ layers.append(Upfirdn2dBlur(blur_kernel, pad=(pad0, pad1)))
+
+ stride = 2
+ self.padding = 0
+
+ else:
+ stride = 1
+ self.padding = kernel_size // 2
+
+ layers.append(
+ EqualConv2D(
+ in_channel,
+ out_channel,
+ kernel_size,
+ padding=self.padding,
+ stride=stride,
+ bias=bias and not activate,
+ )
+ )
+
+ if activate:
+ layers.append(FusedLeakyReLU(out_channel, bias=bias))
+
+ super().__init__(*layers)
+
+
+class ResBlock(nn.Layer):
+ def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
+ super().__init__()
+
+ self.conv1 = ConvLayer(in_channel, in_channel, 3)
+ self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
+
+ self.skip = ConvLayer(
+ in_channel, out_channel, 1, downsample=True, activate=False, bias=False
+ )
+
+ def forward(self, input):
+ out = self.conv1(input)
+ out = self.conv2(out)
+
+ skip = self.skip(input)
+ out = (out + skip) / math.sqrt(2)
+
+ return out
+
+
+@DISCRIMINATORS.register()
+class StyleGANv2Discriminator(nn.Layer):
+ def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
+ super().__init__()
+
+ channels = {
+ 4: 512,
+ 8: 512,
+ 16: 512,
+ 32: 512,
+ 64: 256 * channel_multiplier,
+ 128: 128 * channel_multiplier,
+ 256: 64 * channel_multiplier,
+ 512: 32 * channel_multiplier,
+ 1024: 16 * channel_multiplier,
+ }
+
+ convs = [ConvLayer(3, channels[size], 1)]
+
+ log_size = int(math.log(size, 2))
+
+ in_channel = channels[size]
+
+ for i in range(log_size, 2, -1):
+ out_channel = channels[2 ** (i - 1)]
+
+ convs.append(ResBlock(in_channel, out_channel, blur_kernel))
+
+ in_channel = out_channel
+
+ self.convs = nn.Sequential(*convs)
+
+ self.stddev_group = 4
+ self.stddev_feat = 1
+
+ self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
+ self.final_linear = nn.Sequential(
+ EqualLinear(channels[4] * 4 * 4, channels[4], activation="fused_lrelu"),
+ EqualLinear(channels[4], 1),
+ )
+
+ def forward(self, input):
+ out = self.convs(input)
+
+ batch, channel, height, width = out.shape
+ group = min(batch, self.stddev_group)
+ stddev = out.reshape((
+ group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
+ ))
+ stddev = paddle.sqrt(stddev.var(0, unbiased=False) + 1e-8)
+ stddev = stddev.mean([2, 3, 4], keepdim=True).squeeze(2)
+ stddev = stddev.tile((group, 1, height, width))
+ out = paddle.concat([out, stddev], 1)
+
+ out = self.final_conv(out)
+
+ out = out.reshape((batch, -1))
+ out = self.final_linear(out)
+
+ return out
diff --git a/ppgan/models/generators/__init__.py b/ppgan/models/generators/__init__.py
index ad04c1cdf2d164eaa62b288a902aedf6594dcdb9..8c0feda68125f18d42e01f013aeb49795ab0a5e9 100644
--- a/ppgan/models/generators/__init__.py
+++ b/ppgan/models/generators/__init__.py
@@ -20,4 +20,6 @@ from .deep_conv import DeepConvGenerator, ConditionalDeepConvGenerator
from .resnet_ugatit import ResnetUGATITGenerator
from .dcgenerator import DCGenerator
from .generater_animegan import AnimeGenerator, AnimeGeneratorLite
-from .wav2lip import Wav2Lip
\ No newline at end of file
+from .wav2lip import Wav2Lip
+from .generator_styleganv2 import StyleGANv2Generator
+from .generator_pixel2style2pixel import Pixel2Style2Pixel
diff --git a/ppgan/models/generators/generator_pixel2style2pixel.py b/ppgan/models/generators/generator_pixel2style2pixel.py
new file mode 100644
index 0000000000000000000000000000000000000000..1651cc54c01b45df3a837d51e53a295f4a45b199
--- /dev/null
+++ b/ppgan/models/generators/generator_pixel2style2pixel.py
@@ -0,0 +1,384 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+import numpy as np
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+from collections import namedtuple
+
+from .builder import GENERATORS
+from .generator_styleganv2 import StyleGANv2Generator
+from ...modules.equalized import EqualLinear
+
+
+class Flatten(nn.Layer):
+ def forward(self, input):
+ return input.reshape((input.shape[0], -1))
+
+
+def l2_norm(input, axis=1):
+ norm = paddle.norm(input, 2, axis, True)
+ output = paddle.div(input, norm)
+ return output
+
+
+class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
+ """ A named tuple describing a ResNet block. """
+
+
+def get_block(in_channel, depth, num_units, stride=2):
+ return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)]
+
+
+def get_blocks(num_layers):
+ if num_layers == 50:
+ blocks = [
+ get_block(in_channel=64, depth=64, num_units=3),
+ get_block(in_channel=64, depth=128, num_units=4),
+ get_block(in_channel=128, depth=256, num_units=14),
+ get_block(in_channel=256, depth=512, num_units=3)
+ ]
+ elif num_layers == 100:
+ blocks = [
+ get_block(in_channel=64, depth=64, num_units=3),
+ get_block(in_channel=64, depth=128, num_units=13),
+ get_block(in_channel=128, depth=256, num_units=30),
+ get_block(in_channel=256, depth=512, num_units=3)
+ ]
+ elif num_layers == 152:
+ blocks = [
+ get_block(in_channel=64, depth=64, num_units=3),
+ get_block(in_channel=64, depth=128, num_units=8),
+ get_block(in_channel=128, depth=256, num_units=36),
+ get_block(in_channel=256, depth=512, num_units=3)
+ ]
+ else:
+ raise ValueError("Invalid number of layers: {}. Must be one of [50, 100, 152]".format(num_layers))
+ return blocks
+
+
+class SEModule(nn.Layer):
+ def __init__(self, channels, reduction):
+ super(SEModule, self).__init__()
+ self.avg_pool = nn.AdaptiveAvgPool2D(1)
+ self.fc1 = nn.Conv2D(channels, channels // reduction, kernel_size=1, padding=0, bias_attr=False)
+ self.relu = nn.ReLU()
+ self.fc2 = nn.Conv2D(channels // reduction, channels, kernel_size=1, padding=0, bias_attr=False)
+ self.sigmoid = nn.Sigmoid()
+
+ def forward(self, x):
+ module_input = x
+ x = self.avg_pool(x)
+ x = self.fc1(x)
+ x = self.relu(x)
+ x = self.fc2(x)
+ x = self.sigmoid(x)
+ return module_input * x
+
+
+class BottleneckIR(nn.Layer):
+ def __init__(self, in_channel, depth, stride):
+ super(BottleneckIR, self).__init__()
+ if in_channel == depth:
+ self.shortcut_layer = nn.MaxPool2D(1, stride)
+ else:
+ self.shortcut_layer = nn.Sequential(
+ nn.Conv2D(in_channel, depth, (1, 1), stride, bias_attr=False),
+ nn.BatchNorm2D(depth)
+ )
+ self.res_layer = nn.Sequential(
+ nn.BatchNorm2D(in_channel),
+ nn.Conv2D(in_channel, depth, (3, 3), (1, 1), 1, bias_attr=False), nn.PReLU(depth),
+ nn.Conv2D(depth, depth, (3, 3), stride, 1, bias_attr=False), nn.BatchNorm2D(depth)
+ )
+
+ def forward(self, x):
+ shortcut = self.shortcut_layer(x)
+ res = self.res_layer(x)
+ return res + shortcut
+
+
+class BottleneckIRSE(nn.Layer):
+ def __init__(self, in_channel, depth, stride):
+ super(BottleneckIRSE, self).__init__()
+ if in_channel == depth:
+ self.shortcut_layer = nn.MaxPool2D(1, stride)
+ else:
+ self.shortcut_layer = nn.Sequential(
+ nn.Conv2D(in_channel, depth, (1, 1), stride, bias_attr=False),
+ nn.BatchNorm2D(depth)
+ )
+ self.res_layer = nn.Sequential(
+ nn.BatchNorm2D(in_channel),
+ nn.Conv2D(in_channel, depth, (3, 3), (1, 1), 1, bias_attr=False),
+ nn.PReLU(depth),
+ nn.Conv2D(depth, depth, (3, 3), stride, 1, bias_attr=False),
+ nn.BatchNorm2D(depth),
+ SEModule(depth, 16)
+ )
+
+ def forward(self, x):
+ shortcut = self.shortcut_layer(x)
+ res = self.res_layer(x)
+ return res + shortcut
+
+
+class GradualStyleBlock(nn.Layer):
+ def __init__(self, in_c, out_c, spatial):
+ super(GradualStyleBlock, self).__init__()
+ self.out_c = out_c
+ self.spatial = spatial
+ num_pools = int(np.log2(spatial))
+ modules = []
+ modules += [nn.Conv2D(in_c, out_c, kernel_size=3, stride=2, padding=1),
+ nn.LeakyReLU()]
+ for i in range(num_pools - 1):
+ modules += [
+ nn.Conv2D(out_c, out_c, kernel_size=3, stride=2, padding=1),
+ nn.LeakyReLU()
+ ]
+ self.convs = nn.Sequential(*modules)
+ self.linear = EqualLinear(out_c, out_c, lr_mul=1)
+
+ def forward(self, x):
+ x = self.convs(x)
+ x = x.reshape((-1, self.out_c))
+ x = self.linear(x)
+ return x
+
+
+class GradualStyleEncoder(nn.Layer):
+ def __init__(self, num_layers, mode='ir', opts=None):
+ super(GradualStyleEncoder, self).__init__()
+ assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152'
+ assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
+ blocks = get_blocks(num_layers)
+ if mode == 'ir':
+ unit_module = BottleneckIR
+ elif mode == 'ir_se':
+ unit_module = BottleneckIRSE
+ self.input_layer = nn.Sequential(nn.Conv2D(opts.input_nc, 64, (3, 3), 1, 1, bias_attr=False),
+ nn.BatchNorm2D(64),
+ nn.PReLU(64))
+ modules = []
+ for block in blocks:
+ for bottleneck in block:
+ modules.append(unit_module(bottleneck.in_channel,
+ bottleneck.depth,
+ bottleneck.stride))
+ self.body = nn.Sequential(*modules)
+
+ self.styles = nn.LayerList()
+ self.style_count = 18
+ self.coarse_ind = 3
+ self.middle_ind = 7
+ for i in range(self.style_count):
+ if i < self.coarse_ind:
+ style = GradualStyleBlock(512, 512, 16)
+ elif i < self.middle_ind:
+ style = GradualStyleBlock(512, 512, 32)
+ else:
+ style = GradualStyleBlock(512, 512, 64)
+ self.styles.append(style)
+ self.latlayer1 = nn.Conv2D(256, 512, kernel_size=1, stride=1, padding=0)
+ self.latlayer2 = nn.Conv2D(128, 512, kernel_size=1, stride=1, padding=0)
+
+ def _upsample_add(self, x, y):
+ '''Upsample and add two feature maps.
+ Args:
+ x: (Tensor) top feature map to be upsampled.
+ y: (Tensor) lateral feature map.
+ Returns:
+ (Tensor) added feature map.
+ Note that in Paddle, when the input size is odd, the upsampled feature map
+ with `F.upsample(..., scale_factor=2, mode='nearest')`
+ maybe not equal to the lateral feature map size.
+ e.g.
+ original input size: [N,_,15,15] ->
+ conv2d feature map size: [N,_,8,8] ->
+ upsampled feature map size: [N,_,16,16]
+ So we choose bilinear upsample which supports arbitrary output sizes.
+ '''
+ _, _, H, W = y.shape
+ return F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True) + y
+
+ def forward(self, x):
+ x = self.input_layer(x)
+
+ latents = []
+ modulelist = list(self.body._sub_layers.values())
+ for i, l in enumerate(modulelist):
+ x = l(x)
+ if i == 6:
+ c1 = x
+ elif i == 20:
+ c2 = x
+ elif i == 23:
+ c3 = x
+
+ for j in range(self.coarse_ind):
+ latents.append(self.styles[j](c3))
+
+ p2 = self._upsample_add(c3, self.latlayer1(c2))
+ for j in range(self.coarse_ind, self.middle_ind):
+ latents.append(self.styles[j](p2))
+
+ p1 = self._upsample_add(p2, self.latlayer2(c1))
+ for j in range(self.middle_ind, self.style_count):
+ latents.append(self.styles[j](p1))
+
+ out = paddle.stack(latents, 1)
+ return out
+
+
+class BackboneEncoderUsingLastLayerIntoW(nn.Layer):
+ def __init__(self, num_layers, mode='ir', opts=None):
+ super(BackboneEncoderUsingLastLayerIntoW, self).__init__()
+ print('Using BackboneEncoderUsingLastLayerIntoW')
+ assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152'
+ assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
+ blocks = get_blocks(num_layers)
+ if mode == 'ir':
+ unit_module = BottleneckIR
+ elif mode == 'ir_se':
+ unit_module = BottleneckIRSE
+ self.input_layer = nn.Sequential(nn.Conv2D(opts.input_nc, 64, (3, 3), 1, 1, bias_attr=False),
+ nn.BatchNorm2D(64),
+ nn.PReLU(64))
+ self.output_pool = nn.AdaptiveAvgPool2D((1, 1))
+ self.linear = EqualLinear(512, 512, lr_mul=1)
+ modules = []
+ for block in blocks:
+ for bottleneck in block:
+ modules.append(unit_module(bottleneck.in_channel,
+ bottleneck.depth,
+ bottleneck.stride))
+ self.body = nn.Sequential(*modules)
+
+ def forward(self, x):
+ x = self.input_layer(x)
+ x = self.body(x)
+ x = self.output_pool(x)
+ x = x.reshape((-1, 512))
+ x = self.linear(x)
+ return x
+
+
+class BackboneEncoderUsingLastLayerIntoWPlus(nn.Layer):
+ def __init__(self, num_layers, mode='ir', opts=None):
+ super(BackboneEncoderUsingLastLayerIntoWPlus, self).__init__()
+ print('Using BackboneEncoderUsingLastLayerIntoWPlus')
+ assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152'
+ assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
+ blocks = get_blocks(num_layers)
+ if mode == 'ir':
+ unit_module = BottleneckIR
+ elif mode == 'ir_se':
+ unit_module = BottleneckIRSE
+ self.input_layer = nn.Sequential(nn.Conv2D(opts.input_nc, 64, (3, 3), 1, 1, bias_attr=False),
+ nn.BatchNorm2D(64),
+ nn.PReLU(64))
+ self.output_layer_2 = nn.Sequential(nn.BatchNorm2D(512),
+ nn.AdaptiveAvgPool2D((7, 7)),
+ Flatten(),
+ nn.Linear(512 * 7 * 7, 512))
+ self.linear = EqualLinear(512, 512 * 18, lr_mul=1)
+ modules = []
+ for block in blocks:
+ for bottleneck in block:
+ modules.append(unit_module(bottleneck.in_channel,
+ bottleneck.depth,
+ bottleneck.stride))
+ self.body = nn.Sequential(*modules)
+
+ def forward(self, x):
+ x = self.input_layer(x)
+ x = self.body(x)
+ x = self.output_layer_2(x)
+ x = self.linear(x)
+ x = x.reshape((-1, 18, 512))
+ return x
+
+
+@GENERATORS.register()
+class Pixel2Style2Pixel(nn.Layer):
+ def __init__(self, opts):
+ super(Pixel2Style2Pixel, self).__init__()
+ self.set_opts(opts)
+ # Define architecture
+ self.encoder = self.set_encoder()
+ self.decoder = StyleGANv2Generator(opts.size, opts.style_dim, opts.n_mlp, opts.channel_multiplier)
+ self.face_pool = nn.AdaptiveAvgPool2D((256, 256))
+ self.style_dim = self.decoder.style_dim
+ self.n_latent = self.decoder.n_latent
+ if self.opts.start_from_latent_avg:
+ if self.opts.learn_in_w:
+ self.register_buffer('latent_avg', paddle.zeros([1, self.style_dim]))
+ else:
+ self.register_buffer('latent_avg', paddle.zeros([1, self.n_latent, self.style_dim]))
+
+ def set_encoder(self):
+ if self.opts.encoder_type == 'GradualStyleEncoder':
+ encoder = GradualStyleEncoder(50, 'ir_se', self.opts)
+ elif self.opts.encoder_type == 'BackboneEncoderUsingLastLayerIntoW':
+ encoder = BackboneEncoderUsingLastLayerIntoW(50, 'ir_se', self.opts)
+ elif self.opts.encoder_type == 'BackboneEncoderUsingLastLayerIntoWPlus':
+ encoder = BackboneEncoderUsingLastLayerIntoWPlus(50, 'ir_se', self.opts)
+ else:
+ raise Exception('{} is not a valid encoder type'.format(self.opts.encoder_type))
+ return encoder
+
+ def forward(self, x, resize=True, latent_mask=None, input_code=False, randomize_noise=True,
+ inject_latent=None, return_latents=False, alpha=None):
+ if input_code:
+ codes = x
+ else:
+ codes = self.encoder(x)
+ # normalize with respect to the center of an average face
+ if self.opts.start_from_latent_avg:
+ if self.opts.learn_in_w:
+ codes = codes + self.latent_avg.tile([codes.shape[0], 1])
+ else:
+ codes = codes + self.latent_avg.tile([codes.shape[0], 1, 1])
+
+
+ if latent_mask is not None:
+ for i in latent_mask:
+ if inject_latent is not None:
+ if alpha is not None:
+ codes[:, i] = alpha * inject_latent[:, i] + (1 - alpha) * codes[:, i]
+ else:
+ codes[:, i] = inject_latent[:, i]
+ else:
+ codes[:, i] = 0
+
+ input_is_latent = not input_code
+ images, result_latent = self.decoder([codes],
+ input_is_latent=input_is_latent,
+ randomize_noise=randomize_noise,
+ return_latents=return_latents)
+
+ if resize:
+ images = self.face_pool(images)
+
+ if return_latents:
+ return images, result_latent
+ else:
+ return images
+
+ def set_opts(self, opts):
+ self.opts = opts
diff --git a/ppgan/models/generators/generator_styleganv2.py b/ppgan/models/generators/generator_styleganv2.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c0ccbaaf4cfa969792523de6ff4439876e41c09
--- /dev/null
+++ b/ppgan/models/generators/generator_styleganv2.py
@@ -0,0 +1,395 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+import random
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+from .builder import GENERATORS
+from ...modules.equalized import EqualLinear
+from ...modules.fused_act import FusedLeakyReLU
+from ...modules.upfirdn2d import Upfirdn2dUpsample, Upfirdn2dBlur
+
+
+class PixelNorm(nn.Layer):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, input):
+ return input * paddle.rsqrt(paddle.mean(input ** 2, 1, keepdim=True) + 1e-8)
+
+
+class ModulatedConv2D(nn.Layer):
+ def __init__(
+ self,
+ in_channel,
+ out_channel,
+ kernel_size,
+ style_dim,
+ demodulate=True,
+ upsample=False,
+ downsample=False,
+ blur_kernel=[1, 3, 3, 1],
+ ):
+ super().__init__()
+
+ self.eps = 1e-8
+ self.kernel_size = kernel_size
+ self.in_channel = in_channel
+ self.out_channel = out_channel
+ self.upsample = upsample
+ self.downsample = downsample
+
+ if upsample:
+ factor = 2
+ p = (len(blur_kernel) - factor) - (kernel_size - 1)
+ pad0 = (p + 1) // 2 + factor - 1
+ pad1 = p // 2 + 1
+
+ self.blur = Upfirdn2dBlur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)
+
+ if downsample:
+ factor = 2
+ p = (len(blur_kernel) - factor) + (kernel_size - 1)
+ pad0 = (p + 1) // 2
+ pad1 = p // 2
+
+ self.blur = Upfirdn2dBlur(blur_kernel, pad=(pad0, pad1))
+
+ fan_in = in_channel * kernel_size ** 2
+ self.scale = 1 / math.sqrt(fan_in)
+ self.padding = kernel_size // 2
+
+ self.weight = self.create_parameter(
+ (1, out_channel, in_channel, kernel_size, kernel_size), default_initializer=nn.initializer.Normal()
+ )
+
+ self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
+
+ self.demodulate = demodulate
+
+ def __repr__(self):
+ return (
+ f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, "
+ f"upsample={self.upsample}, downsample={self.downsample})"
+ )
+
+ def forward(self, input, style):
+ batch, in_channel, height, width = input.shape
+
+ style = self.modulation(style).reshape((batch, 1, in_channel, 1, 1))
+ weight = self.scale * self.weight * style
+
+ if self.demodulate:
+ demod = paddle.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
+ weight = weight * demod.reshape((batch, self.out_channel, 1, 1, 1))
+
+ weight = weight.reshape((
+ batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
+ ))
+
+ if self.upsample:
+ input = input.reshape((1, batch * in_channel, height, width))
+ weight = weight.reshape((
+ batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
+ ))
+ weight = weight.transpose((0, 2, 1, 3, 4)).reshape((
+ batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
+ ))
+ out = F.conv2d_transpose(input, weight, padding=0, stride=2, groups=batch)
+ _, _, height, width = out.shape
+ out = out.reshape((batch, self.out_channel, height, width))
+ out = self.blur(out)
+
+ elif self.downsample:
+ input = self.blur(input)
+ _, _, height, width = input.shape
+ input = input.reshape((1, batch * in_channel, height, width))
+ out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
+ _, _, height, width = out.shape
+ out = out.reshape((batch, self.out_channel, height, width))
+
+ else:
+ input = input.reshape((1, batch * in_channel, height, width))
+ out = F.conv2d(input, weight, padding=self.padding, groups=batch)
+ _, _, height, width = out.shape
+ out = out.reshape((batch, self.out_channel, height, width))
+
+ return out
+
+
+class NoiseInjection(nn.Layer):
+ def __init__(self):
+ super().__init__()
+
+ self.weight = self.create_parameter((1,), default_initializer=nn.initializer.Constant(0.0))
+
+ def forward(self, image, noise=None):
+ if noise is None:
+ batch, _, height, width = image.shape
+ noise = paddle.randn((batch, 1, height, width))
+
+ return image + self.weight * noise
+
+
+class ConstantInput(nn.Layer):
+ def __init__(self, channel, size=4):
+ super().__init__()
+
+ self.input = self.create_parameter((1, channel, size, size), default_initializer=nn.initializer.Normal())
+
+ def forward(self, input):
+ batch = input.shape[0]
+ out = self.input.tile((batch, 1, 1, 1))
+
+ return out
+
+
+class StyledConv(nn.Layer):
+ def __init__(
+ self,
+ in_channel,
+ out_channel,
+ kernel_size,
+ style_dim,
+ upsample=False,
+ blur_kernel=[1, 3, 3, 1],
+ demodulate=True,
+ ):
+ super().__init__()
+
+ self.conv = ModulatedConv2D(
+ in_channel,
+ out_channel,
+ kernel_size,
+ style_dim,
+ upsample=upsample,
+ blur_kernel=blur_kernel,
+ demodulate=demodulate,
+ )
+
+ self.noise = NoiseInjection()
+ self.activate = FusedLeakyReLU(out_channel)
+
+ def forward(self, input, style, noise=None):
+ out = self.conv(input, style)
+ out = self.noise(out, noise=noise)
+ out = self.activate(out)
+
+ return out
+
+
+class ToRGB(nn.Layer):
+ def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
+ super().__init__()
+
+ if upsample:
+ self.upsample = Upfirdn2dUpsample(blur_kernel)
+
+ self.conv = ModulatedConv2D(in_channel, 3, 1, style_dim, demodulate=False)
+ self.bias = self.create_parameter((1, 3, 1, 1), nn.initializer.Constant(0.0))
+
+ def forward(self, input, style, skip=None):
+ out = self.conv(input, style)
+ out = out + self.bias
+
+ if skip is not None:
+ skip = self.upsample(skip)
+
+ out = out + skip
+
+ return out
+
+
+@GENERATORS.register()
+class StyleGANv2Generator(nn.Layer):
+ def __init__(
+ self,
+ size,
+ style_dim,
+ n_mlp,
+ channel_multiplier=2,
+ blur_kernel=[1, 3, 3, 1],
+ lr_mlp=0.01,
+ ):
+ super().__init__()
+
+ self.size = size
+
+ self.style_dim = style_dim
+
+ layers = [PixelNorm()]
+
+ for i in range(n_mlp):
+ layers.append(
+ EqualLinear(
+ style_dim, style_dim, lr_mul=lr_mlp, activation="fused_lrelu"
+ )
+ )
+
+ self.style = nn.Sequential(*layers)
+
+ self.channels = {
+ 4: 512,
+ 8: 512,
+ 16: 512,
+ 32: 512,
+ 64: 256 * channel_multiplier,
+ 128: 128 * channel_multiplier,
+ 256: 64 * channel_multiplier,
+ 512: 32 * channel_multiplier,
+ 1024: 16 * channel_multiplier,
+ }
+
+ self.input = ConstantInput(self.channels[4])
+ self.conv1 = StyledConv(
+ self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
+ )
+ self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)
+
+ self.log_size = int(math.log(size, 2))
+ self.num_layers = (self.log_size - 2) * 2 + 1
+
+ self.convs = nn.LayerList()
+ self.upsamples = nn.LayerList()
+ self.to_rgbs = nn.LayerList()
+ self.noises = nn.Layer()
+
+ in_channel = self.channels[4]
+
+ for layer_idx in range(self.num_layers):
+ res = (layer_idx + 5) // 2
+ shape = [1, 1, 2 ** res, 2 ** res]
+ self.noises.register_buffer(f"noise_{layer_idx}", paddle.randn(shape))
+
+ for i in range(3, self.log_size + 1):
+ out_channel = self.channels[2 ** i]
+
+ self.convs.append(
+ StyledConv(
+ in_channel,
+ out_channel,
+ 3,
+ style_dim,
+ upsample=True,
+ blur_kernel=blur_kernel,
+ )
+ )
+
+ self.convs.append(
+ StyledConv(
+ out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
+ )
+ )
+
+ self.to_rgbs.append(ToRGB(out_channel, style_dim))
+
+ in_channel = out_channel
+
+ self.n_latent = self.log_size * 2 - 2
+
+ def make_noise(self):
+ noises = [paddle.randn((1, 1, 2 ** 2, 2 ** 2))]
+
+ for i in range(3, self.log_size + 1):
+ for _ in range(2):
+ noises.append(paddle.randn((1, 1, 2 ** i, 2 ** i)))
+
+ return noises
+
+ def mean_latent(self, n_latent):
+ latent_in = paddle.randn((
+ n_latent, self.style_dim
+ ))
+ latent = self.style(latent_in).mean(0, keepdim=True)
+
+ return latent
+
+ def get_latent(self, input):
+ return self.style(input)
+
+ def forward(
+ self,
+ styles,
+ return_latents=False,
+ inject_index=None,
+ truncation=1,
+ truncation_latent=None,
+ input_is_latent=False,
+ noise=None,
+ randomize_noise=True,
+ ):
+ if not input_is_latent:
+ styles = [self.style(s) for s in styles]
+
+ if noise is None:
+ if randomize_noise:
+ noise = [None] * self.num_layers
+ else:
+ noise = [
+ getattr(self.noises, f"noise_{i}") for i in range(self.num_layers)
+ ]
+
+ if truncation < 1:
+ style_t = []
+
+ for style in styles:
+ style_t.append(
+ truncation_latent + truncation * (style - truncation_latent)
+ )
+
+ styles = style_t
+
+ if len(styles) < 2:
+ inject_index = self.n_latent
+
+ if styles[0].ndim < 3:
+ latent = styles[0].unsqueeze(1).tile((1, inject_index, 1))
+
+ else:
+ latent = styles[0]
+
+ else:
+ if inject_index is None:
+ inject_index = random.randint(1, self.n_latent - 1)
+
+ latent = styles[0].unsqueeze(1).tile((1, inject_index, 1))
+ latent2 = styles[1].unsqueeze(1).tile((1, self.n_latent - inject_index, 1))
+
+ latent = paddle.concat([latent, latent2], 1)
+
+ out = self.input(latent)
+ out = self.conv1(out, latent[:, 0], noise=noise[0])
+
+ skip = self.to_rgb1(out, latent[:, 1])
+
+ i = 1
+ for conv1, conv2, noise1, noise2, to_rgb in zip(
+ self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
+ ):
+ out = conv1(out, latent[:, i], noise=noise1)
+ out = conv2(out, latent[:, i + 1], noise=noise2)
+ skip = to_rgb(out, latent[:, i + 2], skip)
+
+ i += 2
+
+ image = skip
+
+ if return_latents:
+ return image, latent
+
+ else:
+ return image, None
diff --git a/ppgan/modules/equalized.py b/ppgan/modules/equalized.py
new file mode 100644
index 0000000000000000000000000000000000000000..7280ab0e212f7309f2125e19e83cca59b096f31e
--- /dev/null
+++ b/ppgan/modules/equalized.py
@@ -0,0 +1,102 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+from .fused_act import fused_leaky_relu
+
+
+class EqualConv2D(nn.Layer):
+ """This convolutional layer class stabilizes the learning rate changes of its parameters.
+ Equalizing learning rate keeps the weights in the network at a similar scale during training.
+ """
+ def __init__(
+ self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
+ ):
+ super().__init__()
+
+ self.weight = self.create_parameter(
+ (out_channel, in_channel, kernel_size, kernel_size), default_initializer=nn.initializer.Normal()
+ )
+ self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
+
+ self.stride = stride
+ self.padding = padding
+
+ if bias:
+ self.bias = self.create_parameter((out_channel,), nn.initializer.Constant(0.0))
+
+ else:
+ self.bias = None
+
+ def forward(self, input):
+ out = F.conv2d(
+ input,
+ self.weight * self.scale,
+ bias=self.bias,
+ stride=self.stride,
+ padding=self.padding,
+ )
+
+ return out
+
+ def __repr__(self):
+ return (
+ f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},"
+ f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})"
+ )
+
+
+class EqualLinear(nn.Layer):
+ """This linear layer class stabilizes the learning rate changes of its parameters.
+ Equalizing learning rate keeps the weights in the network at a similar scale during training.
+ """
+ def __init__(
+ self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
+ ):
+ super().__init__()
+
+ self.weight = self.create_parameter((in_dim, out_dim), default_initializer=nn.initializer.Normal())
+ self.weight[:] = (self.weight / lr_mul).detach()
+
+ if bias:
+ self.bias = self.create_parameter((out_dim,), nn.initializer.Constant(bias_init))
+
+ else:
+ self.bias = None
+
+ self.activation = activation
+
+ self.scale = (1 / math.sqrt(in_dim)) * lr_mul
+ self.lr_mul = lr_mul
+
+ def forward(self, input):
+ if self.activation:
+ out = F.linear(input, self.weight * self.scale)
+ out = fused_leaky_relu(out, self.bias * self.lr_mul)
+
+ else:
+ out = F.linear(
+ input, self.weight * self.scale, bias=self.bias * self.lr_mul
+ )
+
+ return out
+
+ def __repr__(self):
+ return (
+ f"{self.__class__.__name__}({self.weight.shape[0]}, {self.weight.shape[1]})"
+ )
diff --git a/ppgan/modules/fused_act.py b/ppgan/modules/fused_act.py
new file mode 100644
index 0000000000000000000000000000000000000000..8723af36c2799c5f3e82d6d4b2baccf70a347cce
--- /dev/null
+++ b/ppgan/modules/fused_act.py
@@ -0,0 +1,48 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+
+class FusedLeakyReLU(nn.Layer):
+ def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5):
+ super().__init__()
+
+ if bias:
+ self.bias = self.create_parameter((channel,), default_initializer=nn.initializer.Constant(0.0))
+
+ else:
+ self.bias = None
+
+ self.negative_slope = negative_slope
+ self.scale = scale
+
+ def forward(self, input):
+ return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
+
+
+def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5):
+ if bias is not None:
+ rest_dim = [1] * (input.ndim - bias.ndim - 1)
+ return (
+ F.leaky_relu(
+ input + bias.reshape((1, bias.shape[0], *rest_dim)), negative_slope=0.2
+ )
+ * scale
+ )
+
+ else:
+ return F.leaky_relu(input, negative_slope=0.2) * scale
diff --git a/ppgan/modules/upfirdn2d.py b/ppgan/modules/upfirdn2d.py
new file mode 100644
index 0000000000000000000000000000000000000000..856378a62dd14c613787fd9ecead77d036b27467
--- /dev/null
+++ b/ppgan/modules/upfirdn2d.py
@@ -0,0 +1,143 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+
+def upfirdn2d_native(
+ input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
+):
+ _, channel, in_h, in_w = input.shape
+ input = input.reshape((-1, in_h, in_w, 1))
+
+ _, in_h, in_w, minor = input.shape
+ kernel_h, kernel_w = kernel.shape
+
+ out = input.reshape((-1, in_h, 1, in_w, 1, minor))
+ out = out.transpose((0,1,3,5,2,4))
+ out = out.reshape((-1,1,1,1))
+ out = F.pad(out, [0, up_x - 1, 0, up_y - 1])
+ out = out.reshape((-1, in_h, in_w, minor, up_y, up_x))
+ out = out.transpose((0,3,1,4,2,5))
+ out = out.reshape((-1, minor, in_h * up_y, in_w * up_x))
+
+ out = F.pad(
+ out, [max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
+ )
+ out = out[
+ :,:,
+ max(-pad_y0, 0) : out.shape[2] - max(-pad_y1, 0),
+ max(-pad_x0, 0) : out.shape[3] - max(-pad_x1, 0),
+ ]
+
+ out = out.reshape((
+ [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
+ ))
+ w = paddle.flip(kernel, [0, 1]).reshape((1, 1, kernel_h, kernel_w))
+ out = F.conv2d(out, w)
+ out = out.reshape((
+ -1,
+ minor,
+ in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
+ in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
+ ))
+ out = out.transpose((0, 2, 3, 1))
+ out = out[:, ::down_y, ::down_x, :]
+
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
+
+ return out.reshape((-1, channel, out_h, out_w))
+
+
+def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
+ out = upfirdn2d_native(
+ input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]
+ )
+
+ return out
+
+
+def make_kernel(k):
+ k = paddle.to_tensor(k, dtype='float32')
+
+ if k.ndim == 1:
+ k = k.unsqueeze(0) * k.unsqueeze(1)
+
+ k /= k.sum()
+
+ return k
+
+
+class Upfirdn2dUpsample(nn.Layer):
+ def __init__(self, kernel, factor=2):
+ super().__init__()
+
+ self.factor = factor
+ kernel = make_kernel(kernel) * (factor ** 2)
+ self.register_buffer("kernel", kernel)
+
+ p = kernel.shape[0] - factor
+
+ pad0 = (p + 1) // 2 + factor - 1
+ pad1 = p // 2
+
+ self.pad = (pad0, pad1)
+
+ def forward(self, input):
+ out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
+
+ return out
+
+
+class Upfirdn2dDownsample(nn.Layer):
+ def __init__(self, kernel, factor=2):
+ super().__init__()
+
+ self.factor = factor
+ kernel = make_kernel(kernel)
+ self.register_buffer("kernel", kernel)
+
+ p = kernel.shape[0] - factor
+
+ pad0 = (p + 1) // 2
+ pad1 = p // 2
+
+ self.pad = (pad0, pad1)
+
+ def forward(self, input):
+ out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
+
+ return out
+
+
+class Upfirdn2dBlur(nn.Layer):
+ def __init__(self, kernel, pad, upsample_factor=1):
+ super().__init__()
+
+ kernel = make_kernel(kernel)
+
+ if upsample_factor > 1:
+ kernel = kernel * (upsample_factor ** 2)
+
+ self.register_buffer("kernel", kernel)
+
+ self.pad = pad
+
+ def forward(self, input):
+ out = upfirdn2d(input, self.kernel, pad=self.pad)
+
+ return out