From 9a31bdf95d7395ab97cd6cb1b896a54759ecdf27 Mon Sep 17 00:00:00 2001 From: bitcjm <1274375400@qq.com> Date: Wed, 18 May 2022 15:44:55 +0800 Subject: [PATCH] add the model file and (#634) * add the model files and tipc of GPEN * add the model files and tipc of GPEN --- applications/tools/gpen.py | 93 ++++ configs/gpen_256_ffhq.yaml | 77 +++ docs/en_US/tutorials/gpen.md | 202 ++++++++ docs/zh_CN/tutorials/gpen.md | 205 ++++++++ ppgan/apps/__init__.py | 1 + ppgan/apps/gpen_predictor.py | 140 ++++++ ppgan/datasets/__init__.py | 1 + ppgan/datasets/gpen_dataset.py | 401 ++++++++++++++++ ppgan/models/__init__.py | 1 + ppgan/models/criterions/IDLoss/helpers.py | 141 ++++++ ppgan/models/criterions/IDLoss/id_loss.py | 79 +++ ppgan/models/criterions/IDLoss/model_irse.py | 67 +++ ppgan/models/criterions/__init__.py | 1 + ppgan/models/discriminators/__init__.py | 2 +- .../discriminator_styleganv2.py | 71 +++ ppgan/models/generators/__init__.py | 1 + ppgan/models/generators/generator_gpen.py | 453 ++++++++++++++++++ ppgan/models/generators/gpen.py | 86 ++-- ppgan/models/gpen_model.py | 199 ++++++++ ppgan/modules/equalized.py | 49 ++ test_tipc/configs/GPEN/train_infer_python.txt | 51 ++ test_tipc/prepare.sh | 8 + tools/export_model.py | 4 +- tools/inference.py | 42 +- 24 files changed, 2324 insertions(+), 51 deletions(-) create mode 100644 applications/tools/gpen.py create mode 100644 configs/gpen_256_ffhq.yaml create mode 100644 docs/en_US/tutorials/gpen.md create mode 100644 docs/zh_CN/tutorials/gpen.md create mode 100644 ppgan/apps/gpen_predictor.py create mode 100644 ppgan/datasets/gpen_dataset.py create mode 100644 ppgan/models/criterions/IDLoss/helpers.py create mode 100644 ppgan/models/criterions/IDLoss/id_loss.py create mode 100644 ppgan/models/criterions/IDLoss/model_irse.py create mode 100644 ppgan/models/generators/generator_gpen.py create mode 100644 ppgan/models/gpen_model.py create mode 100644 test_tipc/configs/GPEN/train_infer_python.txt diff --git a/applications/tools/gpen.py b/applications/tools/gpen.py new file mode 100644 index 0000000..d77cd1c --- /dev/null +++ b/applications/tools/gpen.py @@ -0,0 +1,93 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
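+#
+# Command-line demo for GPEN blind face restoration: the image given by
+# --test_img is synthetically degraded, restored with a trained (or
+# automatically downloaded pretrained) GPEN generator, and a side-by-side
+# comparison together with FID and PSNR scores is written to --output_path.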
+ +import sys + +sys.path.append(".") +import argparse +import paddle +from ppgan.apps import GPENPredictor + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--output_path", + type=str, + default='output_dir', + help="path to output image dir") + + parser.add_argument("--weight_path", + type=str, + default=None, + help="path to model checkpoint path") + + parser.add_argument("--test_img", + type=str, + default='data/gpen/lite_data/15006.png', + help="path of test image") + + parser.add_argument("--model_type", + type=str, + default=None, + help="type of model for loading pretrained model") + + parser.add_argument("--seed", + type=int, + default=None, + help="sample random seed for model's image generation") + + parser.add_argument("--size", + type=int, + default=256, + help="resolution of output image") + + parser.add_argument("--style_dim", + type=int, + default=512, + help="number of style dimension") + + parser.add_argument("--n_mlp", + type=int, + default=8, + help="number of mlp layer depth") + + parser.add_argument("--channel_multiplier", + type=int, + default=1, + help="number of channel multiplier") + + parser.add_argument("--narrow", + type=float, + default=0.5, + help="number of channel narrow") + + parser.add_argument("--cpu", + dest="cpu", + action="store_true", + help="cpu mode.") + + args = parser.parse_args() + + if args.cpu: + paddle.set_device('cpu') + + predictor = GPENPredictor(output_path=args.output_path, + weight_path=args.weight_path, + model_type=args.model_type, + seed=args.seed, + size=args.size, + style_dim=args.style_dim, + n_mlp=args.n_mlp, + narrow=args.narrow, + channel_multiplier=args.channel_multiplier) + predictor.run(args.test_img) diff --git a/configs/gpen_256_ffhq.yaml b/configs/gpen_256_ffhq.yaml new file mode 100644 index 0000000..a5e0936 --- /dev/null +++ b/configs/gpen_256_ffhq.yaml @@ -0,0 +1,77 @@ +total_iters: 200000 +output_dir: output_dir +find_unused_parameters: True + + +model: + name: GPENModel + generator: + name: GPEN + size: 256 + style_dim: 512 + n_mlp: 8 + channel_multiplier: 1 + narrow: 0.5 + discriminator: + name: GPENDiscriminator + size: 256 + channel_multiplier: 1 + narrow: 0.5 + + +export_model: + - {name: 'g_ema', inputs_num: 1} + +dataset: + train: + name: GPENDataset + dataroot: data/ffhq/images256x256/ + num_workers: 0 + batch_size: 2 #1gpus + size: 256 + + test: + name: GPENDataset + dataroot: data/ffhq/images256x256/ + num_workers: 0 + batch_size: 1 + size: 256 + amount: 100 + + +lr_scheduler: + name: CosineAnnealingRestartLR + learning_rate: 0.002 + periods: [500000, 500000, 500000, 500000] + restart_weights: [1, 1, 1, 1] + eta_min: 0.002 + + +optimizer: + optimG: + name: Adam + net_names: + - netG + beta1: 0.9 + beta2: 0.99 + optimD: + name: Adam + net_names: + - netD + beta1: 0.9 + beta2: 0.99 + +log_config: + interval: 100 + visiual_interval: 500 + +snapshot_config: + interval: 5000 + +validate: + interval: 5000 + save_img: false + metrics: + fid: + name: FID + batch_size: 1 diff --git a/docs/en_US/tutorials/gpen.md b/docs/en_US/tutorials/gpen.md new file mode 100644 index 0000000..0384a44 --- /dev/null +++ b/docs/en_US/tutorials/gpen.md @@ -0,0 +1,202 @@ +English | [Chinese](../../zh_CN/tutorials/gpen.md) + +## GPEN Blind Face Restoration Model + + +## 1、Introduction + +The GPEN model is a blind face restoration model. 
The authors embed the decoder of the previously proposed StyleGAN V2 into GPEN as its decoder, and rebuild a simple DNN encoder that provides the decoder's input. In this way the model keeps the excellent performance of the StyleGAN V2 decoder while its function changes from image style transfer to blind face restoration. The overall structure of the model is shown in the following figure:
+
+![img](https://user-images.githubusercontent.com/23252220/168281766-a0972bd3-243e-4fc7-baa5-e458ef0946ce.jpg)
+
+For a more detailed introduction to the model and the reference repo, see the latest version of the AI Studio project [GPEN Blind Face Restoration Model Reproduction - Paddle AI Studio](https://aistudio.baidu.com/aistudio/projectdetail/3936241?contributionType=1).
+
+
+
+
+## 2、Preparation
+
+### 2.1 Dataset Preparation
+
+The GPEN training set is the classic FFHQ face dataset, with a total of 70,000 clear 1024 x 1024 high-resolution face images; the test set is the CELEBA-HQ dataset, with a total of 2,000 high-resolution face images. For details, please refer to the **dataset URLs:** [FFHQ](https://github.com/NVlabs/ffhq-dataset), [CELEBA-HQ](https://github.com/tkarras/progressive_growing_of_gans). The specific download links are given below:
+
+**Original dataset download addresses:**
+
+**FFHQ :** https://drive.google.com/drive/folders/1tZUcXDBeOibC6jcMCtgRRz67pzrAHeHL?usp=drive_open
+
+**CELEBA-HQ:** https://drive.google.com/drive/folders/0B4qLcYyJmiz0TXY1NG02bzZVRGs?resourcekey=0-arAVTUfW9KRhN-irJchVKQ&usp=sharing
+
+
+
+Since the original FFHQ dataset is very large, you can also download the 256-resolution FFHQ dataset from the following link:
+
+https://paddlegan.bj.bcebos.com/datasets/images256x256.tar
+
+
+
+**After downloading, the files are organized as follows**
+
+```
+|-- data/GPEN
+	|-- ffhq/images256x256/
+		|-- 00000
+			|-- 00000.png
+			|-- 00001.png
+			|-- ......
+			|-- 00999.png
+		|-- 01000
+			|-- ......
+		|-- ......
+		|-- 69000
+			|-- ......
+			|-- 69999.png
+	|-- test
+		|-- 2000 png images
+```
+
+Please set the dataroot parameters of dataset train and test in the configs/gpen_256_ffhq.yaml configuration file to your training set and test set paths.
+
+
+
+### 2.2 Model preparation
+
+**Model parameter file and training log download address:**
+
+link: https://paddlegan.bj.bcebos.com/models/gpen.zip
+
+
+Download the model parameters and test images from the link and put them in the data/ folder in the project root directory. The specific file structure is as follows:
+
+
+```
+data/gpen/weights
+    |-- model_ir_se50.pdparams
+    |-- weight_pretrain.pdparams
+data/gpen/lite_data
+```
+
+
+
+## 3、Getting started
+
+### 3.1 Model training
+
+Enter the following command in the console to start training:
+
+ ```shell
+ python tools/main.py -c configs/gpen_256_ffhq.yaml
+ ```
+
+The model only supports single-card training.
+
+Model training requires paddle 2.3 or later, and also has to wait for paddle to implement the second-order operator of elementwise_pow: paddle 2.2.2 runs the code, but some loss functions compute wrong gradients, so the model cannot be trained successfully. If an error is reported during training, training is not supported for the time being; you can skip the training part and test directly with the provided model parameters. Model evaluation and testing work with paddle 2.2.2 and above.
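+
+Any value in the yaml config can also be overridden from the command line with `-o` instead of editing the file (the same mechanism the evaluation command in 3.2 uses). A minimal sketch, assuming the 256-resolution FFHQ images were unpacked to the path shown:
+
+ ```shell
+ python tools/main.py -c configs/gpen_256_ffhq.yaml -o dataset.train.dataroot=data/ffhq/images256x256/
+ ```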
+
+
+
+### 3.2 Model evaluation
+
+To evaluate the model with the downloaded model parameters mentioned above, enter the following command in the console:
+
+ ```shell
+python tools/main.py -c configs/gpen_256_ffhq.yaml -o dataset.test.amount=2000 --load data/gpen/weights/weight_pretrain.pdparams --evaluate-only
+ ```
+
+If you want to test a model you provide yourself, modify the path after --load .
+
+
+
+### 3.3 Model prediction
+
+#### 3.3.1 Export generator weights
+
+After training, use ``tools/extract_weight.py`` to extract the generator's weights from the trained model (which contains both the generator and the discriminator), so that `applications/tools/gpen.py` can run inference and realize the various applications of the GPEN model. Enter the following command to extract the generator's weights:
+
+```bash
+python tools/extract_weight.py data/gpen/weights/weight_pretrain.pdparams --net-name g_ema --output data/gpen/weights/g_ema.pdparams
+```
+
+
+
+#### 3.3.2 Process a single image
+
+After extracting the generator's weights, enter the following command to test the image under the --test_img path. Changing the --seed parameter generates different degraded images, showing richer effects. You can change the path after --test_img to any image you want to test. If no weight is provided after the --weight_path parameter, the trained model weights are downloaded automatically for testing.
+
+```bash
+python applications/tools/gpen.py --test_img data/gpen/lite_data/15006.png --seed=100 --weight_path data/gpen/weights/g_ema.pdparams --model_type gpen-ffhq-256
+```
+
+The following shows a sample image and its restored counterparts; from left to right: the degraded image, the generated image, and the original clear image:
+

+
+An example output is as follows:
+
+
+```
+result saved in : output_dir/gpen_predict.png
+	FID: 92.11730631094356
+	PSNR:19.014782083825743
+```
+
+
+
+## 4. TIPC
+
+### 4.1 Export the inference model
+
+```bash
+python tools/export_model.py -c configs/gpen_256_ffhq.yaml --inputs_size=1,3,256,256 --load data/gpen/weights/weight_pretrain.pdparams
+```
+
+The above command generates the model structure file `gpenmodel_g_ema.pdmodel` and the model weight files `gpenmodel_g_ema.pdiparams` and `gpenmodel_g_ema.pdiparams.info`, all required for prediction and stored in the `inference_model/` directory. You can also change the parameter after --load to the model parameter file you want to test.
+
+
+
+### 4.2 Inference with a prediction engine
+
+```bash
+python tools/inference.py --model_type GPEN --seed 100 -c configs/gpen_256_ffhq.yaml -o dataset.test.dataroot="./data/gpen/lite_data/" --output_path test_tipc/output/ --model_path inference_model/gpenmodel_g_ema
+```
+
+When inference finishes, the restored images generated by the model are saved in the test_tipc/output/GPEN directory by default, and the FID value obtained in the test is written to test_tipc/output/GPEN/metric.txt.
+
+
+The default output is as follows:
+
+```
+Metric fid: 187.0158
+```
+
+Note: Since degrading the high-definition images involves some randomness, the results differ between runs. To keep test results consistent, the random seed is fixed here so that the same degradation is applied to the images in every test.
+
+
+
+### 4.3 Call the scripts to complete the training-and-inference test in two steps
+
+To test the basic training and inference functionality, invoke the `lite_train_lite_infer` mode by running:
+
+```shell
+# fix the line-ending format of the shell scripts
+sed -i 's/\r//' test_tipc/prepare.sh
+sed -i 's/\r//' test_tipc/test_train_inference_python.sh
+sed -i 's/\r//' test_tipc/common_func.sh
+# prepare data
+bash test_tipc/prepare.sh ./test_tipc/configs/GPEN/train_infer_python.txt 'lite_train_lite_infer'
+# run the test
+bash test_tipc/test_train_inference_python.sh ./test_tipc/configs/GPEN/train_infer_python.txt 'lite_train_lite_infer'
+```
+
+
+
+## 5、References
+
+```
+@inproceedings{yang2021gpen,
+    title={GAN Prior Embedded Network for Blind Face Restoration in the Wild},
+    author={Yang, Tao and Ren, Peiran and Xie, Xuansong and Zhang, Lei},
+    booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
+    year={2021}
+}
+```
+
diff --git a/docs/zh_CN/tutorials/gpen.md b/docs/zh_CN/tutorials/gpen.md
new file mode 100644
index 0000000..e303e6c
--- /dev/null
+++ b/docs/zh_CN/tutorials/gpen.md
@@ -0,0 +1,205 @@
+[English](../../en_US/tutorials/gpen.md) | 中文
+
+## GPEN 盲人脸修复模型
+
+
+## 1、简介
+
+GPEN模型是一个盲人脸修复模型。作者将前人提出的 StyleGAN V2 的解码器嵌入模型，作为GPEN的解码器；用DNN重新构建了一种简单的编码器，为解码器提供输入。这样模型在保留了 StyleGAN V2 解码器优秀的性能的基础上，将模型的功能由图像风格转换变为了盲人脸修复。模型的总体结构如下图所示：
+
+![img](https://user-images.githubusercontent.com/23252220/168281766-a0972bd3-243e-4fc7-baa5-e458ef0946ce.jpg)
+
+对模型更详细的介绍和参考repo，可查看以下AI Studio项目[链接](https://aistudio.baidu.com/aistudio/projectdetail/3936241?contributionType=1)的最新版本。
+
+
+
+
+## 2、准备工作
+
+### 2.1 数据集准备
+
+GPEN模型训练集是经典的FFHQ人脸数据集，共70000张1024 x 1024高分辨率的清晰人脸图片，测试集是CELEBA-HQ数据集，共2000张高分辨率人脸图片。详细信息可以参考**数据集网址:** [FFHQ](https://github.com/NVlabs/ffhq-dataset) ，[CELEBA-HQ](https://github.com/tkarras/progressive_growing_of_gans) 。以下给出了具体的下载链接：
+
+**原数据集下载地址:**
+
+**FFHQ :** https://drive.google.com/drive/folders/1tZUcXDBeOibC6jcMCtgRRz67pzrAHeHL?usp=drive_open
+
+**CELEBA-HQ:** https://drive.google.com/drive/folders/0B4qLcYyJmiz0TXY1NG02bzZVRGs?resourcekey=0-arAVTUfW9KRhN-irJchVKQ&usp=sharing
+
+
+
+由于FFHQ原数据集过大，也可以从以下链接下载256分辨率的FFHQ数据集：
+
+https://paddlegan.bj.bcebos.com/datasets/images256x256.tar
+
+
+
+**下载后，文件参考组织形式如下**
+
+```
+|-- data/GPEN
+	|-- ffhq/images256x256/
+		|-- 00000
+			|-- 00000.png
+			|-- 00001.png
+			|-- ......
+			|-- 00999.png
+		|-- 01000
+			|-- ......
+		|-- ......
+		|-- 69000
+			|-- ......
+			|-- 69999.png
+	|-- test
+		|-- 2000张png图片
+```
+
+请修改configs/gpen_256_ffhq.yaml配置文件中dataset的train和test的dataroot参数为你的训练集和测试集路径。
+
+
+
+### 2.2 模型准备
+
+**模型参数文件及训练日志下载地址:**
+
+链接：https://paddlegan.bj.bcebos.com/models/gpen.zip
+
+
+从链接中下载模型参数和测试图片，并放到项目根目录下的data/文件夹下，具体文件结构如下所示：
+
+**文件结构**
+
+
+```
+data/gpen/weights
+    |-- model_ir_se50.pdparams   # 计算id_loss需要加载的facenet的模型参数文件
+    |-- weight_pretrain.pdparams # 256分辨率的包含生成器和判别器的模型参数文件，其中只有生成器的参数是训练好的参数，参数文件的格式与3.1训练过程中保存的参数文件格式相同。3.2、3.3.1、4.1也需要用到该参数文件
+data/gpen/lite_data
+```
+
+
+
+## 3、开始使用
+
+### 3.1 模型训练
+
+在控制台输入以下代码，开始训练：
+
+ ```shell
+ python tools/main.py -c configs/gpen_256_ffhq.yaml
+ ```
+
+模型只支持单卡训练。
+
+模型训练需使用paddle2.3及以上版本，且需等paddle实现elementwise_pow的二阶算子相关功能；使用paddle2.2.2版本能正常运行，但因部分损失函数会求出错误梯度，导致模型无法训练成功。如训练时报错则暂不支持进行训练，可跳过训练部分，直接使用提供的模型参数进行测试。模型评估和测试使用paddle2.2.2及以上版本即可。
+
+
+
+### 3.2 模型评估
+
+对模型进行评估时，在控制台输入以下代码，下面代码中使用上面提到的下载的模型参数：
+
+ ```shell
+python tools/main.py -c configs/gpen_256_ffhq.yaml -o dataset.test.amount=2000 --load data/gpen/weights/weight_pretrain.pdparams --evaluate-only
+ ```
+
+如果要在自己提供的模型上进行测试，请修改 --load 后面的路径。
+
+
+
+### 3.3 模型预测
+
+#### 3.3.1 导出生成器权重
+
+训练结束后，需要使用 ``tools/extract_weight.py`` 来从训练模型（包含了生成器和判别器）中提取生成器的权重来给`applications/tools/gpen.py`进行推理，以实现GPEN模型的各种应用。输入以下命令来提取生成器的权重：
+
+```bash
+python tools/extract_weight.py data/gpen/weights/weight_pretrain.pdparams --net-name g_ema --output data/gpen/weights/g_ema.pdparams
+```
+
+
+
+#### 3.3.2 对单张图像进行处理
+
+提取完生成器的权重后，输入以下命令可对--test_img路径下图片进行测试。修改--seed参数，可生成不同的退化图像，展示出更丰富的效果。可修改--test_img后的路径为你想测试的任意图片。如--weight_path参数后不提供权重，则会自动下载训练好的模型权重进行测试。
+
+```bash
+python applications/tools/gpen.py --test_img data/gpen/lite_data/15006.png --seed=100 --weight_path data/gpen/weights/g_ema.pdparams --model_type gpen-ffhq-256
+```
+
+以下是样例图片和对应的修复图像，从左到右依次是退化图像、生成的图像和原始清晰图像：
+

+
+
+
+输出示例如下：
+
+```
+result saved in : output_dir/gpen_predict.png
+	FID: 92.11730631094356
+	PSNR:19.014782083825743
+```
+
+
+
+## 4. TIPC
+
+### 4.1 导出inference模型
+
+```bash
+python tools/export_model.py -c configs/gpen_256_ffhq.yaml --inputs_size=1,3,256,256 --load data/gpen/weights/weight_pretrain.pdparams
+```
+
+上述命令将生成预测所需的模型结构文件`gpenmodel_g_ema.pdmodel`和模型权重文件`gpenmodel_g_ema.pdiparams`以及`gpenmodel_g_ema.pdiparams.info`文件，均存放在`inference_model/`目录下。也可以修改--load 后的参数为你想测试的模型参数文件。
+
+
+
+### 4.2 使用预测引擎推理
+
+```bash
+python tools/inference.py --model_type GPEN --seed 100 -c configs/gpen_256_ffhq.yaml -o dataset.test.dataroot="./data/gpen/lite_data/" --output_path test_tipc/output/ --model_path inference_model/gpenmodel_g_ema
+```
+
+推理结束后，默认会将模型生成的修复图像保存在test_tipc/output/GPEN目录下，并在test_tipc/output/GPEN/metric.txt中输出测试得到的FID值。
+
+
+默认输出如下：
+
+```
+Metric fid: 187.0158
+```
+
+注：由于对高清图片进行退化的操作具有一定的随机性，所以每次测试的结果都会有所不同。为了保证测试结果一致，在这里固定了随机种子，使每次测试时对图片都进行相同的退化操作。
+
+
+
+### 4.3 调用脚本两步完成训推一体测试
+
+测试基本训练预测功能的`lite_train_lite_infer`模式，运行：
+
+```shell
+# 修正脚本文件格式（去除\r换行符）
+sed -i 's/\r//' test_tipc/prepare.sh
+sed -i 's/\r//' test_tipc/test_train_inference_python.sh
+sed -i 's/\r//' test_tipc/common_func.sh
+# 准备数据
+bash test_tipc/prepare.sh ./test_tipc/configs/GPEN/train_infer_python.txt 'lite_train_lite_infer'
+# 运行测试
+bash test_tipc/test_train_inference_python.sh ./test_tipc/configs/GPEN/train_infer_python.txt 'lite_train_lite_infer'
+```
+
+
+
+## 5、参考文献
+
+```
+@inproceedings{yang2021gpen,
+    title={GAN Prior Embedded Network for Blind Face Restoration in the Wild},
+    author={Yang, Tao and Ren, Peiran and Xie, Xuansong and Zhang, Lei},
+    booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
+    year={2021}
+}
+```
diff --git a/ppgan/apps/__init__.py b/ppgan/apps/__init__.py
index 15d7d25..008b15c 100644
--- a/ppgan/apps/__init__.py
+++ b/ppgan/apps/__init__.py
@@ -35,3 +35,4 @@ from .recurrent_vsr_predictor import (PPMSVSRPredictor, BasicVSRPredictor, \
                                       BasiVSRPlusPlusPredictor, IconVSRPredictor, \
                                       PPMSVSRLargePredictor)
 from .singan_predictor import SinGANPredictor
+from .gpen_predictor import GPENPredictor
diff --git a/ppgan/apps/gpen_predictor.py b/ppgan/apps/gpen_predictor.py
new file mode 100644
index 0000000..95f9811
--- /dev/null
+++ b/ppgan/apps/gpen_predictor.py
@@ -0,0 +1,140 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
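+#
+# GPENPredictor loads a GPEN generator from --weight_path (or downloads the
+# pretrained weights for a known --model_type), degrades the input image with
+# the same GFPGAN_degradation pipeline used during training, restores it, and
+# reports FID and PSNR of the restored image against the clean original.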
+import math +import os +import random +import numpy as np +import paddle +import sys + +sys.path.append(".") +from .base_predictor import BasePredictor +from ppgan.datasets.gpen_dataset import GFPGAN_degradation +from ppgan.models.generators import GPEN +from ppgan.metrics.fid import FID +from ppgan.utils.download import get_path_from_url +import cv2 + +import warnings + +model_cfgs = { + 'gpen-ffhq-256': { + 'model_urls': + 'https://paddlegan.bj.bcebos.com/models/gpen-ffhq-256-generator.pdparams', + 'size': 256, + 'style_dim': 512, + 'n_mlp': 8, + 'channel_multiplier': 1, + 'narrow': 0.5 + } +} + + +def psnr(pred, gt): + pred = paddle.clip(pred, min=0, max=1) + gt = paddle.clip(gt, min=0, max=1) + imdff = np.asarray(pred - gt) + rmse = math.sqrt(np.mean(imdff**2)) + if rmse == 0: + return 100 + return 20 * math.log10(1.0 / rmse) + + +def data_loader(path, size=256): + degrader = GFPGAN_degradation() + + img_gt = cv2.imread(path, cv2.IMREAD_COLOR) + + img_gt = cv2.resize(img_gt, (size, size), interpolation=cv2.INTER_NEAREST) + + img_gt = img_gt.astype(np.float32) / 255. + img_gt, img_lq = degrader.degrade_process(img_gt) + + img_gt = (paddle.to_tensor(img_gt) - 0.5) / 0.5 + img_lq = (paddle.to_tensor(img_lq) - 0.5) / 0.5 + + img_gt = img_gt.transpose([2, 0, 1]).flip(0).unsqueeze(0) + img_lq = img_lq.transpose([2, 0, 1]).flip(0).unsqueeze(0) + + return np.array(img_lq).astype('float32'), np.array(img_gt).astype( + 'float32') + + +class GPENPredictor(BasePredictor): + + def __init__(self, + output_path='output_dir', + weight_path=None, + model_type=None, + seed=100, + size=256, + style_dim=512, + n_mlp=8, + channel_multiplier=1, + narrow=0.5): + self.output_path = output_path + self.size = size + if weight_path is None: + if model_type in model_cfgs.keys(): + weight_path = get_path_from_url( + model_cfgs[model_type]['model_urls']) + size = model_cfgs[model_type].get('size', size) + style_dim = model_cfgs[model_type].get('style_dim', style_dim) + n_mlp = model_cfgs[model_type].get('n_mlp', n_mlp) + channel_multiplier = model_cfgs[model_type].get( + 'channel_multiplier', channel_multiplier) + narrow = model_cfgs[model_type].get('narrow', narrow) + checkpoint = paddle.load(weight_path) + else: + raise ValueError( + 'Predictor need a weight path or a pretrained model type') + else: + checkpoint = paddle.load(weight_path) + + warnings.filterwarnings("always") + self.generator = GPEN(size, style_dim, n_mlp, channel_multiplier, + narrow) + self.generator.set_state_dict(checkpoint) + self.generator.eval() + + if seed is not None: + paddle.seed(seed) + random.seed(seed) + np.random.seed(seed) + + def run(self, img_path): + os.makedirs(self.output_path, exist_ok=True) + input_array, target_array = data_loader(img_path, self.size) + input_tensor = paddle.to_tensor(input_array) + target_tensor = paddle.to_tensor(target_array) + + FID_model = FID(use_GPU=True) + + with paddle.no_grad(): + output, _ = self.generator(input_tensor) + psnr_score = psnr(target_tensor, output) + FID_model.update(output, target_tensor) + fid_score = FID_model.accumulate() + + input_tensor = input_tensor.transpose([0, 2, 3, 1]) + target_tensor = target_tensor.transpose([0, 2, 3, 1]) + output = output.transpose([0, 2, 3, 1]) + sample_result = paddle.concat( + (input_tensor[0], output[0], target_tensor[0]), 1) + sample = cv2.cvtColor((sample_result.numpy() + 1) / 2 * 255, + cv2.COLOR_RGB2BGR) + file_name = self.output_path + '/gpen_predict.png' + cv2.imwrite(file_name, sample) + print(f"result saved in : {file_name}") + 
print(f"\tFID: {fid_score}\n\tPSNR:{psnr_score}") diff --git a/ppgan/datasets/__init__.py b/ppgan/datasets/__init__.py index 2137193..b0c2014 100755 --- a/ppgan/datasets/__init__.py +++ b/ppgan/datasets/__init__.py @@ -30,3 +30,4 @@ from .vsr_vimeo90k_dataset import VSRVimeo90KDataset from .vsr_folder_dataset import VSRFolderDataset from .photopen_dataset import PhotoPenDataset from .empty_dataset import EmptyDataset +from .gpen_dataset import GPENDataset diff --git a/ppgan/datasets/gpen_dataset.py b/ppgan/datasets/gpen_dataset.py new file mode 100644 index 0000000..b4584f8 --- /dev/null +++ b/ppgan/datasets/gpen_dataset.py @@ -0,0 +1,401 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import glob +import logging +import os +import numpy as np +import paddle +from paddle.io import Dataset +import cv2 + +from .builder import DATASETS + +import math +import random + +logger = logging.getLogger(__name__) + + +def generate_gaussian_noise(img, sigma=10, gray_noise=False): + """Generate Gaussian noise. + + Args: + img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32. + sigma (float): Noise scale (measured in range 255). Default: 10. + + Returns: + (Numpy array): Returned noisy image, shape (h, w, c), range[0, 1], + float32. + """ + if gray_noise: + noise = np.float32(np.random.randn(*(img.shape[0:2]))) * sigma / 255. + noise = np.expand_dims(noise, axis=2).repeat(3, axis=2) + else: + noise = np.float32(np.random.randn(*(img.shape))) * sigma / 255. + return noise + + +def random_generate_gaussian_noise(img, sigma_range=(0, 10), gray_prob=0): + sigma = np.random.uniform(sigma_range[0], sigma_range[1]) + if np.random.uniform() < gray_prob: + gray_noise = True + else: + gray_noise = False + return generate_gaussian_noise(img, sigma, gray_noise) + + +def random_add_gaussian_noise(img, + sigma_range=(0, 1.0), + gray_prob=0, + clip=True, + rounds=False): + noise = random_generate_gaussian_noise(img, sigma_range, gray_prob) + out = img + noise + if clip and rounds: + out = np.clip((out * 255.0).round(), 0, 255) / 255. + elif clip: + out = np.clip(out, 0, 1) + elif rounds: + out = (out * 255.0).round() / 255. + return out + + +def add_jpg_compression(img, quality=90): + """Add JPG compression artifacts. + + Args: + img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32. + quality (float): JPG compression quality. 0 for lowest quality, 100 for + best quality. Default: 90. + + Returns: + (Numpy array): Returned image after JPG, shape (h, w, c), range[0, 1], + float32. + """ + img = np.clip(img, 0, 1) + encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), quality] + _, encimg = cv2.imencode('.jpg', img * 255., encode_param) + img = np.float32(cv2.imdecode(encimg, 1)) / 255. + return img + + +def random_add_jpg_compression(img, quality_range=(90, 100)): + """Randomly add JPG compression artifacts. + + Args: + img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32. 
+ quality_range (tuple[float] | list[float]): JPG compression quality + range. 0 for lowest quality, 100 for best quality. + Default: (90, 100). + + Returns: + (Numpy array): Returned image after JPG, shape (h, w, c), range[0, 1], + float32. + """ + quality = int(np.random.uniform(quality_range[0], quality_range[1])) + return add_jpg_compression(img, quality) + + +def random_mixed_kernels(kernel_list, + kernel_prob, + kernel_size=21, + sigma_x_range=(0.6, 5), + sigma_y_range=(0.6, 5), + rotation_range=(-math.pi, math.pi), + betag_range=(0.5, 8), + betap_range=(0.5, 8), + noise_range=None): + """Randomly generate mixed kernels. + + Args: + kernel_list (tuple): a list name of kernel types, + support ['iso', 'aniso', 'skew', 'generalized', 'plateau_iso', + 'plateau_aniso'] + kernel_prob (tuple): corresponding kernel probability for each + kernel type + kernel_size (int): + sigma_x_range (tuple): [0.6, 5] + sigma_y_range (tuple): [0.6, 5] + rotation range (tuple): [-math.pi, math.pi] + beta_range (tuple): [0.5, 8] + noise_range(tuple, optional): multiplicative kernel noise, + [0.75, 1.25]. Default: None + + Returns: + kernel (ndarray): + """ + kernel_type = random.choices(kernel_list, kernel_prob)[0] + if kernel_type == 'iso': + kernel = random_bivariate_Gaussian(kernel_size, + sigma_x_range, + sigma_y_range, + rotation_range, + noise_range=noise_range, + isotropic=True) + elif kernel_type == 'aniso': + kernel = random_bivariate_Gaussian(kernel_size, + sigma_x_range, + sigma_y_range, + rotation_range, + noise_range=noise_range, + isotropic=False) + return kernel + + +def random_bivariate_Gaussian(kernel_size, + sigma_x_range, + sigma_y_range, + rotation_range, + noise_range=None, + isotropic=True): + """Randomly generate bivariate isotropic or anisotropic Gaussian kernels. + + In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and `rotation_range` is ignored. + + Args: + kernel_size (int): + sigma_x_range (tuple): [0.6, 5] + sigma_y_range (tuple): [0.6, 5] + rotation range (tuple): [-math.pi, math.pi] + noise_range(tuple, optional): multiplicative kernel noise, + [0.75, 1.25]. Default: None + + Returns: + kernel (ndarray): + """ + assert kernel_size % 2 == 1, 'Kernel size must be an odd number.' + assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.' + sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1]) + if isotropic is False: + assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.' + assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.' + sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1]) + rotation = np.random.uniform(rotation_range[0], rotation_range[1]) + else: + sigma_y = sigma_x + rotation = 0 + + kernel = bivariate_Gaussian(kernel_size, + sigma_x, + sigma_y, + rotation, + isotropic=isotropic) + + # add multiplicative noise + if noise_range is not None: + assert noise_range[0] < noise_range[1], 'Wrong noise range.' + noise = np.random.uniform(noise_range[0], + noise_range[1], + size=kernel.shape) + kernel = kernel * noise + kernel = kernel / np.sum(kernel) + return kernel + + +def bivariate_Gaussian(kernel_size, + sig_x, + sig_y, + theta, + grid=None, + isotropic=True): + """Generate a bivariate isotropic or anisotropic Gaussian kernel. + + In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` is ignored. + + Args: + kernel_size (int): + sig_x (float): + sig_y (float): + theta (float): Radian measurement. 
+ grid (ndarray, optional): generated by :func:`mesh_grid`, + with the shape (K, K, 2), K is the kernel size. Default: None + isotropic (bool): + + Returns: + kernel (ndarray): normalized kernel. + """ + if grid is None: + grid, _, _ = mesh_grid(kernel_size) + if isotropic: + sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]]) + else: + sigma_matrix = sigma_matrix2(sig_x, sig_y, theta) + kernel = pdf2(sigma_matrix, grid) + kernel = kernel / np.sum(kernel) + return kernel + + +def sigma_matrix2(sig_x, sig_y, theta): + """Calculate the rotated sigma matrix (two dimensional matrix). + + Args: + sig_x (float): + sig_y (float): + theta (float): Radian measurement. + + Returns: + ndarray: Rotated sigma matrix. + """ + d_matrix = np.array([[sig_x**2, 0], [0, sig_y**2]]) + u_matrix = np.array([[np.cos(theta), -np.sin(theta)], + [np.sin(theta), np.cos(theta)]]) + return np.dot(u_matrix, np.dot(d_matrix, u_matrix.T)) + + +def mesh_grid(kernel_size): + """Generate the mesh grid, centering at zero. + + Args: + kernel_size (int): + + Returns: + xy (ndarray): with the shape (kernel_size, kernel_size, 2) + xx (ndarray): with the shape (kernel_size, kernel_size) + yy (ndarray): with the shape (kernel_size, kernel_size) + """ + ax = np.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1.) + xx, yy = np.meshgrid(ax, ax) + xy = np.hstack((xx.reshape((kernel_size * kernel_size, 1)), + yy.reshape(kernel_size * kernel_size, + 1))).reshape(kernel_size, kernel_size, 2) + return xy, xx, yy + + +def pdf2(sigma_matrix, grid): + """Calculate PDF of the bivariate Gaussian distribution. + + Args: + sigma_matrix (ndarray): with the shape (2, 2) + grid (ndarray): generated by :func:`mesh_grid`, + with the shape (K, K, 2), K is the kernel size. + + Returns: + kernel (ndarrray): un-normalized kernel. + """ + inverse_sigma = np.linalg.inv(sigma_matrix) + kernel = np.exp(-0.5 * np.sum(np.dot(grid, inverse_sigma) * grid, 2)) + return kernel + + +class GFPGAN_degradation(object): + + def __init__(self): + self.kernel_list = ['iso', 'aniso'] + self.kernel_prob = [0.5, 0.5] + self.blur_kernel_size = 41 + self.blur_sigma = [0.1, 10] + self.downsample_range = [0.8, 8] + self.noise_range = [0, 20] + self.jpeg_range = [60, 100] + self.gray_prob = 0.2 + self.color_jitter_prob = 0.0 + self.color_jitter_pt_prob = 0.0 + self.shift = 20 / 255. 
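+
+    # degrade_process() below implements a GFPGAN-style degradation pipeline:
+    # the HQ face is optionally flipped / color-jittered / grayscaled, then
+    # blurred with a random (an)isotropic Gaussian kernel, downsampled by a
+    # random factor from downsample_range, corrupted with Gaussian noise and
+    # JPEG compression, and finally resized back to the original resolution
+    # to form the LQ training input.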
+
+    def degrade_process(self, img_gt):
+        if random.random() > 0.5:
+            img_gt = cv2.flip(img_gt, 1)
+
+        h, w = img_gt.shape[:2]
+
+        # random color jitter
+        if np.random.uniform() < self.color_jitter_prob:
+            jitter_val = np.random.uniform(-self.shift, self.shift,
+                                           3).astype(np.float32)
+            img_gt = img_gt + jitter_val
+            img_gt = np.clip(img_gt, 0, 1)
+
+        # random grayscale
+        if np.random.uniform() < self.gray_prob:
+            img_gt = cv2.cvtColor(img_gt, cv2.COLOR_BGR2GRAY)
+            img_gt = np.tile(img_gt[:, :, None], [1, 1, 3])
+
+        # ------------------------ generate lq image ------------------------ #
+        # blur
+        kernel = random_mixed_kernels(self.kernel_list,
+                                      self.kernel_prob,
+                                      self.blur_kernel_size,
+                                      self.blur_sigma,
+                                      self.blur_sigma, [-math.pi, math.pi],
+                                      noise_range=None)
+        img_lq = cv2.filter2D(img_gt, -1, kernel)
+        # downsample
+        scale = np.random.uniform(self.downsample_range[0],
+                                  self.downsample_range[1])
+        img_lq = cv2.resize(img_lq, (int(w // scale), int(h // scale)),
+                            interpolation=cv2.INTER_LINEAR)
+
+        # noise
+        if self.noise_range is not None:
+            img_lq = random_add_gaussian_noise(img_lq, self.noise_range)
+        # jpeg compression
+        if self.jpeg_range is not None:
+            img_lq = random_add_jpg_compression(img_lq, self.jpeg_range)
+
+        # round and clip
+        img_lq = np.clip((img_lq * 255.0).round(), 0, 255) / 255.
+
+        # resize to original size
+        img_lq = cv2.resize(img_lq, (w, h), interpolation=cv2.INTER_LINEAR)
+
+        return img_gt, img_lq
+
+
+@DATASETS.register()
+class GPENDataset(Dataset):
+    """Paired HQ/LQ face dataset for GPEN: loads high-quality face images
+    from dataroot and generates the degraded low-quality inputs on the fly
+    with GFPGAN_degradation.
+    """
+
+    def __init__(self, dataroot, size=256, amount=-1):
+        super(GPENDataset, self).__init__()
+        self.size = size
+        self.HQ_imgs = sorted(glob.glob(os.path.join(dataroot,
+                                                     '*/*.*g')))[:amount]
+        self.length = len(self.HQ_imgs)
+        if self.length == 0:
+            self.HQ_imgs = sorted(glob.glob(os.path.join(dataroot,
+                                                         '*.*g')))[:amount]
+            self.length = len(self.HQ_imgs)
+            print(self.length)
+        self.degrader = GFPGAN_degradation()
+
+    def __len__(self):
+        return self.length
+
+    def __getitem__(self, index):
+        """Get a training sample.
+
+        return:
+            img_lq: degraded input image with shape [C,H,W], range [-1, 1]
+            img_gt: clean target image with shape [C,H,W], range [-1, 1]
+        """
+        img_gt = cv2.imread(self.HQ_imgs[index], cv2.IMREAD_COLOR)
+        img_gt = cv2.resize(img_gt, (self.size, self.size),
+                            interpolation=cv2.INTER_AREA)
+
+        # BFR degradation
+        img_gt = img_gt.astype(np.float32) / 255.
+        img_gt, img_lq = self.degrader.degrade_process(img_gt)
+
+        img_gt = (paddle.to_tensor(img_gt) - 0.5) / 0.5
+        img_lq = (paddle.to_tensor(img_lq) - 0.5) / 0.5
+
+        img_gt = img_gt.transpose([2, 0, 1]).flip(0)
+        img_lq = img_lq.transpose([2, 0, 1]).flip(0)
+
+        return np.array(img_lq).astype('float32'), np.array(img_gt).astype(
+            'float32')
diff --git a/ppgan/models/__init__.py b/ppgan/models/__init__.py
index 6f38794..65331da 100644
--- a/ppgan/models/__init__.py
+++ b/ppgan/models/__init__.py
@@ -37,3 +37,4 @@ from .msvsr_model import MultiStageVSRModel
 from .singan_model import SinGANModel
 from .rcan_model import RCANModel
 from .prenet_model import PReNetModel
+from .gpen_model import GPENModel
diff --git a/ppgan/models/criterions/IDLoss/helpers.py b/ppgan/models/criterions/IDLoss/helpers.py
new file mode 100644
index 0000000..68f7571
--- /dev/null
+++ b/ppgan/models/criterions/IDLoss/helpers.py
@@ -0,0 +1,141 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import namedtuple +import paddle +import paddle.nn as nn + + +class Flatten(nn.Layer): + + def forward(self, input): + return paddle.reshape(input, [input.shape[0], -1]) + + +def l2_norm(input, axis=1): + norm = paddle.norm(input, 2, axis, True) + output = paddle.divide(input, norm) + return output + + +class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])): + """ A named tuple describing a ResNet block. """ + + +def get_block(in_channel, depth, num_units, stride=2): + return [Bottleneck(in_channel, depth, stride) + ] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)] + + +def get_blocks(num_layers): + if num_layers == 50: + blocks = [ + get_block(in_channel=64, depth=64, num_units=3), + get_block(in_channel=64, depth=128, num_units=4), + get_block(in_channel=128, depth=256, num_units=14), + get_block(in_channel=256, depth=512, num_units=3) + ] + elif num_layers == 100: + blocks = [ + get_block(in_channel=64, depth=64, num_units=3), + get_block(in_channel=64, depth=128, num_units=13), + get_block(in_channel=128, depth=256, num_units=30), + get_block(in_channel=256, depth=512, num_units=3) + ] + elif num_layers == 152: + blocks = [ + get_block(in_channel=64, depth=64, num_units=3), + get_block(in_channel=64, depth=128, num_units=8), + get_block(in_channel=128, depth=256, num_units=36), + get_block(in_channel=256, depth=512, num_units=3) + ] + else: + raise ValueError( + "Invalid number of layers: {}. Must be one of [50, 100, 152]". 
+ format(num_layers)) + return blocks + + +class SEModule(nn.Layer): + + def __init__(self, channels, reduction): + super(SEModule, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2D(1) + self.fc1 = nn.Conv2D(channels, + channels // reduction, + kernel_size=1, + padding=0, + bias_attr=False) + self.relu = nn.ReLU() + self.fc2 = nn.Conv2D(channels // reduction, + channels, + kernel_size=1, + padding=0, + bias_attr=False) + self.sigmoid = nn.Sigmoid() + + def forward(self, x): + module_input = x + x = self.avg_pool(x) + x = self.fc1(x) + x = self.relu(x) + x = self.fc2(x) + x = self.sigmoid(x) + return module_input * x + + +class bottleneck_IR(nn.Layer): + + def __init__(self, in_channel, depth, stride): + super(bottleneck_IR, self).__init__() + if in_channel == depth: + self.shortcut_layer = nn.MaxPool2D(1, stride) + else: + self.shortcut_layer = nn.Sequential( + nn.Conv2D(in_channel, depth, (1, 1), stride, bias_attr=False), + nn.BatchNorm2D(depth)) + self.res_layer = nn.Sequential( + nn.BatchNorm2D(in_channel), + nn.Conv2D(in_channel, depth, (3, 3), (1, 1), 1, bias_attr=False), + nn.PReLU(depth), + nn.Conv2D(depth, depth, (3, 3), stride, 1, bias_attr=False), + nn.BatchNorm2D(depth)) + + def forward(self, x): + shortcut = self.shortcut_layer(x) + res = self.res_layer(x) + return res + shortcut + + +class bottleneck_IR_SE(nn.Layer): + + def __init__(self, in_channel, depth, stride): + super(bottleneck_IR_SE, self).__init__() + if in_channel == depth: + self.shortcut_layer = nn.MaxPool2D(1, stride) + else: + self.shortcut_layer = nn.Sequential( + nn.Conv2D(in_channel, depth, (1, 1), stride, bias_attr=False), + nn.BatchNorm2D(depth)) + self.res_layer = nn.Sequential( + nn.BatchNorm2D(in_channel), + nn.Conv2D(in_channel, depth, (3, 3), (1, 1), 1, bias_attr=False), + nn.PReLU(depth), + nn.Conv2D(depth, depth, (3, 3), stride, 1, bias_attr=False), + nn.BatchNorm2D(depth), SEModule(depth, 16)) + + def forward(self, x): + shortcut = self.shortcut_layer(x) + res = self.res_layer(x) + return res + shortcut diff --git a/ppgan/models/criterions/IDLoss/id_loss.py b/ppgan/models/criterions/IDLoss/id_loss.py new file mode 100644 index 0000000..99b5b91 --- /dev/null +++ b/ppgan/models/criterions/IDLoss/id_loss.py @@ -0,0 +1,79 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
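+#
+# IDLoss measures identity preservation: the restored face y_hat and the
+# ground-truth face y are cropped and resized to 112x112, embedded with a
+# pretrained IR-SE50 ArcFace backbone into L2-normalized 512-d vectors, and
+# the loss is the batch mean of 1 - <f(y_hat), f(y)>, i.e. one minus the
+# cosine similarity of the two embeddings.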
+ +import os +import paddle +from .model_irse import Backbone +from paddle.vision.transforms import Resize +from ..builder import CRITERIONS +from ppgan.utils.download import get_path_from_url + +model_cfgs = { + 'model_urls': + 'https://paddlegan.bj.bcebos.com/models/model_ir_se50.pdparams', +} + + +@CRITERIONS.register() +class IDLoss(paddle.nn.Layer): + + def __init__(self, base_dir='./'): + super(IDLoss, self).__init__() + print('Loading ResNet ArcFace') + self.facenet = Backbone(input_size=112, + num_layers=50, + drop_ratio=0.6, + mode='ir_se') + + facenet_weights_path = os.path.join(base_dir, 'data/gpen/weights', + 'model_ir_se50.pdparams') + + if not os.path.isfile(facenet_weights_path): + facenet_weights_path = get_path_from_url(model_cfgs['model_urls']) + + self.facenet.load_dict(paddle.load(facenet_weights_path)) + + self.face_pool = paddle.nn.AdaptiveAvgPool2D((112, 112)) + self.facenet.eval() + + def extract_feats(self, x): + _, _, h, w = x.shape + assert h == w + ss = h // 256 + x = x[:, :, 35 * ss:-33 * ss, 32 * ss:-36 * ss] + transform = Resize(size=(112, 112)) + + for num in range(x.shape[0]): + mid_feats = transform(x[num]).unsqueeze(0) + if num == 0: + x_feats = mid_feats + else: + x_feats = paddle.concat([x_feats, mid_feats], axis=0) + + x_feats = self.facenet(x_feats) + return x_feats + + def forward(self, y_hat, y, x): + n_samples = x.shape[0] + y_feats = self.extract_feats(y) + y_hat_feats = self.extract_feats(y_hat) + y_feats = y_feats.detach() + loss = 0 + count = 0 + for i in range(n_samples): + diff_target = y_hat_feats[i].dot(y_feats[i]) + loss += 1 - diff_target + count += 1 + + return loss / count diff --git a/ppgan/models/criterions/IDLoss/model_irse.py b/ppgan/models/criterions/IDLoss/model_irse.py new file mode 100644 index 0000000..b0479df --- /dev/null +++ b/ppgan/models/criterions/IDLoss/model_irse.py @@ -0,0 +1,67 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
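+#
+# IR-SE50 face-recognition backbone used by IDLoss: stacks bottleneck_IR_SE
+# residual units from helpers.py and maps a 112x112 or 224x224 face crop to
+# a single L2-normalized 512-d identity embedding.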
+ +import paddle +import paddle.nn as nn +from .helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm +""" +Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) +""" + + +class Backbone(nn.Layer): + + def __init__(self, + input_size, + num_layers, + mode='ir', + drop_ratio=0.4, + affine=True): + super(Backbone, self).__init__() + assert input_size in [112, 224], "input_size should be 112 or 224" + assert num_layers in [50, 100, + 152], "num_layers should be 50, 100 or 152" + assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se" + blocks = get_blocks(num_layers) + if mode == 'ir': + unit_module = bottleneck_IR + elif mode == 'ir_se': + unit_module = bottleneck_IR_SE + self.input_layer = paddle.nn.Sequential( + nn.Conv2D(3, 64, (3, 3), 1, 1, bias_attr=False), nn.BatchNorm2D(64), + nn.PReLU(64)) + if input_size == 112: + self.output_layer = nn.Sequential(nn.BatchNorm2D(512), + nn.Dropout(drop_ratio), Flatten(), + nn.Linear(512 * 7 * 7, 512), + nn.BatchNorm1D(512)) + else: + self.output_layer = nn.Sequential(nn.BatchNorm2D(512), + nn.Dropout(drop_ratio), Flatten(), + nn.Linear(512 * 14 * 14, 512), + nn.BatchNorm1D(512)) + + modules = [] + for block in blocks: + for bottleneck in block: + modules.append( + unit_module(bottleneck.in_channel, bottleneck.depth, + bottleneck.stride)) + self.body = nn.Sequential(*modules) + + def forward(self, x): + x = self.input_layer(x) + x = self.body(x) + x = self.output_layer(x) + return l2_norm(x) diff --git a/ppgan/models/criterions/__init__.py b/ppgan/models/criterions/__init__.py index e7b0ac1..b9bc3ec 100644 --- a/ppgan/models/criterions/__init__.py +++ b/ppgan/models/criterions/__init__.py @@ -9,3 +9,4 @@ from .gradient_penalty import GradientPenalty from .builder import build_criterion from .ssim import SSIM +from .IDLoss.id_loss import IDLoss diff --git a/ppgan/models/discriminators/__init__.py b/ppgan/models/discriminators/__init__.py index f371047..bacedac 100644 --- a/ppgan/models/discriminators/__init__.py +++ b/ppgan/models/discriminators/__init__.py @@ -17,7 +17,7 @@ from .nlayers import NLayerDiscriminator, NLayerDiscriminatorWithClassification from .discriminator_ugatit import UGATITDiscriminator from .dcdiscriminator import DCDiscriminator from .discriminator_animegan import AnimeDiscriminator -from .discriminator_styleganv2 import StyleGANv2Discriminator +from .discriminator_styleganv2 import StyleGANv2Discriminator, GPENDiscriminator from .syncnet import SyncNetColor from .wav2lip_disc_qual import Wav2LipDiscQual from .discriminator_starganv2 import StarGANv2Discriminator diff --git a/ppgan/models/discriminators/discriminator_styleganv2.py b/ppgan/models/discriminators/discriminator_styleganv2.py index 8acea70..80d6e5b 100644 --- a/ppgan/models/discriminators/discriminator_styleganv2.py +++ b/ppgan/models/discriminators/discriminator_styleganv2.py @@ -28,6 +28,7 @@ from ...modules.upfirdn2d import Upfirdn2dBlur class ConvLayer(nn.Sequential): + def __init__( self, in_channel, @@ -72,6 +73,7 @@ class ConvLayer(nn.Sequential): class ResBlock(nn.Layer): + def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]): super().__init__() @@ -112,6 +114,7 @@ def var(x, axis=None, unbiased=True, keepdim=False, name=None): @DISCRIMINATORS.register() class StyleGANv2Discriminator(nn.Layer): + def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]): super().__init__() @@ -171,3 +174,71 @@ class StyleGANv2Discriminator(nn.Layer): out = 
self.final_linear(out) return out + + +@DISCRIMINATORS.register() +class GPENDiscriminator(nn.Layer): + + def __init__(self, + size, + channel_multiplier=1, + narrow=0.5, + blur_kernel=[1, 3, 3, 1]): + super().__init__() + + channels = { + 4: int(512 * narrow), + 8: int(512 * narrow), + 16: int(512 * narrow), + 32: int(512 * narrow), + 64: int(256 * channel_multiplier * narrow), + 128: int(128 * channel_multiplier * narrow), + 256: int(64 * channel_multiplier * narrow), + 512: int(32 * channel_multiplier * narrow), + 1024: int(16 * channel_multiplier * narrow), + } + + convs = [ConvLayer(3, channels[size], 1)] + + log_size = int(math.log(size, 2)) + + in_channel = channels[size] + + for i in range(log_size, 2, -1): + out_channel = channels[2**(i - 1)] + + convs.append(ResBlock(in_channel, out_channel, blur_kernel)) + + in_channel = out_channel + + self.convs = nn.Sequential(*convs) + + self.stddev_group = 4 + self.stddev_feat = 1 + + self.final_conv = ConvLayer(in_channel + 1, channels[4], 3) + self.final_linear = nn.Sequential( + EqualLinear(channels[4] * 4 * 4, + channels[4], + activation="fused_lrelu"), + EqualLinear(channels[4], 1), + ) + + def forward(self, input): + out = self.convs(input) + + batch, channel, height, width = out.shape + group = min(batch, self.stddev_group) + stddev = out.reshape((group, -1, self.stddev_feat, + channel // self.stddev_feat, height, width)) + stddev = paddle.sqrt(var(stddev, 0, unbiased=False) + 1e-8) + stddev = stddev.mean([2, 3, 4], keepdim=True).squeeze(2) + stddev = stddev.tile((group, 1, height, width)) + out = paddle.concat([out, stddev], 1) + + out = self.final_conv(out) + + out = out.reshape((batch, -1)) + out = self.final_linear(out) + + return out diff --git a/ppgan/models/generators/__init__.py b/ppgan/models/generators/__init__.py index 56572a7..630e105 100755 --- a/ppgan/models/generators/__init__.py +++ b/ppgan/models/generators/__init__.py @@ -41,3 +41,4 @@ from .msvsr import MSVSR from .generator_singan import SinGANGenerator from .rcan import RCAN from .prenet import PReNet +from .gpen import GPEN diff --git a/ppgan/models/generators/generator_gpen.py b/ppgan/models/generators/generator_gpen.py new file mode 100644 index 0000000..de5119e --- /dev/null +++ b/ppgan/models/generators/generator_gpen.py @@ -0,0 +1,453 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
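+#
+# StyleGAN2-style generator used as the GAN-prior decoder inside GPEN. With
+# is_concat=True, NoiseInjection concatenates the supplied noise maps (GPEN
+# feeds the encoder's feature maps in as "noise") along the channel axis
+# instead of adding them, which is how the DNN encoder output enters the
+# decoder at every resolution.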
+ +# code was heavily based on https://github.com/rosinality/stylegan2-pytorch +# MIT License +# Copyright (c) 2019 Kim Seonghyeon + +import math +import random +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + +from ppgan.modules.equalized import EqualLinear_gpen as EqualLinear +from ppgan.modules.fused_act import FusedLeakyReLU +from ppgan.modules.upfirdn2d import Upfirdn2dUpsample, Upfirdn2dBlur + + +class PixelNorm(nn.Layer): + + def __init__(self): + super().__init__() + + def forward(self, inputs): + return inputs * paddle.rsqrt( + paddle.mean(inputs * inputs, 1, keepdim=True) + 1e-8) + + +class ModulatedConv2D(nn.Layer): + + def __init__( + self, + in_channel, + out_channel, + kernel_size, + style_dim, + demodulate=True, + upsample=False, + downsample=False, + blur_kernel=[1, 3, 3, 1], + ): + super().__init__() + + self.eps = 1e-8 + self.kernel_size = kernel_size + self.in_channel = in_channel + self.out_channel = out_channel + self.upsample = upsample + self.downsample = downsample + + if upsample: + factor = 2 + p = (len(blur_kernel) - factor) - (kernel_size - 1) + pad0 = (p + 1) // 2 + factor - 1 + pad1 = p // 2 + 1 + + self.blur = Upfirdn2dBlur(blur_kernel, + pad=(pad0, pad1), + upsample_factor=factor) + + if downsample: + factor = 2 + p = (len(blur_kernel) - factor) + (kernel_size - 1) + pad0 = (p + 1) // 2 + pad1 = p // 2 + + self.blur = Upfirdn2dBlur(blur_kernel, pad=(pad0, pad1)) + + fan_in = in_channel * (kernel_size * kernel_size) + self.scale = 1 / math.sqrt(fan_in) + self.padding = kernel_size // 2 + + self.weight = self.create_parameter( + (1, out_channel, in_channel, kernel_size, kernel_size), + default_initializer=nn.initializer.Normal()) + + self.modulation = EqualLinear(style_dim, in_channel, bias_init=1) + + self.demodulate = demodulate + + def __repr__(self): + return ( + f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, " + f"upsample={self.upsample}, downsample={self.downsample})") + + def forward(self, inputs, style): + batch, in_channel, height, width = inputs.shape + + style = self.modulation(style).reshape((batch, 1, in_channel, 1, 1)) + weight = self.scale * self.weight * style + + if self.demodulate: + demod = paddle.rsqrt((weight * weight).sum([2, 3, 4]) + 1e-8) + weight = weight * demod.reshape((batch, self.out_channel, 1, 1, 1)) + + weight = weight.reshape((batch * self.out_channel, in_channel, + self.kernel_size, self.kernel_size)) + + if self.upsample: + inputs = inputs.reshape((1, batch * in_channel, height, width)) + weight = weight.reshape((batch, self.out_channel, in_channel, + self.kernel_size, self.kernel_size)) + weight = weight.transpose((0, 2, 1, 3, 4)).reshape( + (batch * in_channel, self.out_channel, self.kernel_size, + self.kernel_size)) + out = F.conv2d_transpose(inputs, + weight, + padding=0, + stride=2, + groups=batch) + _, _, height, width = out.shape + out = out.reshape((batch, self.out_channel, height, width)) + out = self.blur(out) + + elif self.downsample: + inputs = self.blur(inputs) + _, _, height, width = inputs.shape + inputs = inputs.reshape((1, batch * in_channel, height, width)) + out = F.conv2d(inputs, weight, padding=0, stride=2, groups=batch) + _, _, height, width = out.shape + out = out.reshape((batch, self.out_channel, height, width)) + + else: + inputs = inputs.reshape((1, batch * in_channel, height, width)) + out = F.conv2d(inputs, weight, padding=self.padding, groups=batch) + _, _, height, width = out.shape + out = out.reshape((batch, self.out_channel, 
height, width)) + + return out + + +class NoiseInjection(nn.Layer): + + def __init__(self, is_concat=False): + super().__init__() + + self.weight = self.create_parameter( + (1, ), default_initializer=nn.initializer.Constant(0.0)) + self.is_concat = is_concat + + def forward(self, image, noise=None): + if noise is None: + batch, _, height, width = image.shape + noise = paddle.randn((batch, 1, height, width)) + if self.is_concat: + return paddle.concat([image, self.weight * noise], axis=1) + else: + return image + self.weight * noise + + +class ConstantInput(nn.Layer): + + def __init__(self, channel, size=4): + super().__init__() + + self.input = self.create_parameter( + (1, channel, size, size), + default_initializer=nn.initializer.Normal()) + + def forward(self, inputs): + batch = inputs.shape[0] + out = self.input.tile((batch, 1, 1, 1)) + + return out + + +class StyledConv(nn.Layer): + + def __init__(self, + in_channel, + out_channel, + kernel_size, + style_dim, + upsample=False, + blur_kernel=[1, 3, 3, 1], + demodulate=True, + is_concat=False): + super().__init__() + + self.conv = ModulatedConv2D( + in_channel, + out_channel, + kernel_size, + style_dim, + upsample=upsample, + blur_kernel=blur_kernel, + demodulate=demodulate, + ) + + self.noise = NoiseInjection(is_concat=is_concat) + self.activate = FusedLeakyReLU(out_channel * + 2 if is_concat else out_channel) + + def forward(self, inputs, style, noise=None): + out = self.conv(inputs, style) + out = self.noise(out, noise=noise) + out = self.activate(out) + + return out + + +class ToRGB(nn.Layer): + + def __init__(self, + in_channel, + style_dim, + upsample=True, + blur_kernel=[1, 3, 3, 1]): + super().__init__() + + if upsample: + self.upsample = Upfirdn2dUpsample(blur_kernel) + + self.conv = ModulatedConv2D(in_channel, + 3, + 1, + style_dim, + demodulate=False) + self.bias = self.create_parameter((1, 3, 1, 1), + nn.initializer.Constant(0.0)) + + def forward(self, inputs, style, skip=None): + out = self.conv(inputs, style) + out = out + self.bias + + if skip is not None: + skip = self.upsample(skip) + + out = out + skip + + return out + + +class StyleGANv2Generator(nn.Layer): + + def __init__(self, + size, + style_dim, + n_mlp, + channel_multiplier=1, + narrow=0.5, + blur_kernel=[1, 3, 3, 1], + lr_mlp=0.01, + is_concat=True): + super().__init__() + + self.size = size + + self.style_dim = style_dim + + layers = [PixelNorm()] + + for i in range(n_mlp): + layers.append( + EqualLinear(style_dim, + style_dim, + lr_mul=lr_mlp, + activation="fused_lrelu")) + + self.style = nn.Sequential(*layers) + + self.channels = { + 4: int(512 * narrow), + 8: int(512 * narrow), + 16: int(512 * narrow), + 32: int(512 * narrow), + 64: int(256 * channel_multiplier * narrow), + 128: int(128 * channel_multiplier * narrow), + 256: int(64 * channel_multiplier * narrow), + 512: int(32 * channel_multiplier * narrow), + 1024: int(16 * channel_multiplier * narrow), + 2048: int(8 * channel_multiplier * narrow) + } + + self.input = ConstantInput(self.channels[4]) + self.conv1 = StyledConv(self.channels[4], + self.channels[4], + 3, + style_dim, + blur_kernel=blur_kernel, + is_concat=is_concat) + self.to_rgb1 = ToRGB(self.channels[4] * + 2 if is_concat else self.channels[4], + style_dim, + upsample=False) + + self.log_size = int(math.log(size, 2)) + self.num_layers = (self.log_size - 2) * 2 + 1 + + self.convs = nn.LayerList() + self.upsamples = nn.LayerList() + self.to_rgbs = nn.LayerList() + self.noises = nn.Layer() + + in_channel = self.channels[4] + + for layer_idx in 
range(self.num_layers): + res = (layer_idx + 5) // 2 + shape = [1, 1, 2**res, 2**res] + self.noises.register_buffer(f"noise_{layer_idx}", + paddle.randn(shape)) + + for i in range(3, self.log_size + 1): + out_channel = self.channels[2**i] + + self.convs.append( + StyledConv( + in_channel * 2 if is_concat else in_channel, + out_channel, + 3, + style_dim, + upsample=True, + blur_kernel=blur_kernel, + is_concat=is_concat, + )) + + self.convs.append( + StyledConv(out_channel * 2 if is_concat else out_channel, + out_channel, + 3, + style_dim, + blur_kernel=blur_kernel, + is_concat=is_concat)) + + self.to_rgbs.append( + ToRGB(out_channel * 2 if is_concat else out_channel, style_dim)) + + in_channel = out_channel + + self.n_latent = self.log_size * 2 - 2 + self.is_concat = is_concat + + def make_noise(self): + noises = [paddle.randn((1, 1, 2**2, 2**2))] + + for i in range(3, self.log_size + 1): + for _ in range(2): + noises.append(paddle.randn((1, 1, 2**i, 2**i))) + + return noises + + def mean_latent(self, n_latent): + latent_in = paddle.randn((n_latent, self.style_dim)) + latent = self.style(latent_in).mean(0, keepdim=True) + + return latent + + def get_latent(self, inputs): + return self.style(inputs) + + def get_mean_style(self): + mean_style = None + with paddle.no_grad(): + for i in range(10): + style = self.mean_latent(1024) + if mean_style is None: + mean_style = style + else: + mean_style += style + + mean_style /= 10 + return mean_style + + def forward( + self, + styles, + return_latents=False, + inject_index=None, + truncation=1, + truncation_latent=None, + input_is_latent=False, + noise=None, + ): + if not input_is_latent: + styles = [self.style(s) for s in styles] + + if noise is None: + ''' + noise = [None] * (2 * (self.log_size - 2) + 1) + ''' + noise = [] + batch = styles[0].shape[0] + for i in range(self.n_mlp + 1): + size = 2**(i + 2) + noise.append( + paddle.create_parameter( + [batch, self.channels[size], size, size], + dtype='float32', + attr=paddle.ParamAttr( + initializer=nn.initializer.Constant(0), + trainable=True))) + + if truncation < 1: + style_t = [] + + for style in styles: + style_t.append(truncation_latent + truncation * + (style - truncation_latent)) + + styles = style_t + + if len(styles) < 2: + inject_index = self.n_latent + + latent = styles[0].unsqueeze(1) + latent = paddle.tile(latent, repeat_times=[1, inject_index, 1]) + else: + if inject_index is None: + inject_index = random.randint(1, self.n_latent - 1) + + latent = paddle.tile(styles[0].unsqueeze(1), + repeat_times=[1, inject_index, 1]) + latent2 = paddle.tile( + styles[1].unsqueeze(1), + repeat_times=[1, self.n_latent - inject_index, 1]) + + latent = paddle.concat([latent, latent2], 1) + + out = self.input(latent) + out = self.conv1(out, latent[:, 0], noise=noise[0]) + + skip = self.to_rgb1(out, latent[:, 1]) + + i = 1 + for conv1, conv2, noise1, noise2, to_rgb in zip(self.convs[::2], + self.convs[1::2], + noise[1::2], + noise[2::2], + self.to_rgbs): + out = conv1(out, latent[:, i], noise=noise1) + out = conv2(out, latent[:, i + 1], noise=noise2) + skip = to_rgb(out, latent[:, i + 2], skip) + + i += 2 + + image = skip + + if return_latents: + return image, latent + + else: + return image, None diff --git a/ppgan/models/generators/gpen.py b/ppgan/models/generators/gpen.py index df72662..34a7913 100644 --- a/ppgan/models/generators/gpen.py +++ b/ppgan/models/generators/gpen.py @@ -12,60 +12,73 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-# code was heavily based on https://github.com/yangxy/GPEN
+# code was heavily based on https://github.com/yangxy/GPEN
 
-import paddle
+import itertools
 import paddle.nn as nn
 import math
 
-from ppgan.models.generators import StyleGANv2Generator
+from ppgan.models.generators.builder import GENERATORS
+from ppgan.modules.equalized import EqualLinear_gpen as EqualLinear
+from ppgan.models.generators.generator_gpen import StyleGANv2Generator
 from ppgan.models.discriminators.discriminator_styleganv2 import ConvLayer
-from ppgan.modules.equalized import EqualLinear
 
 
+@GENERATORS.register()
 class GPEN(nn.Layer):
+
     def __init__(
         self,
         size,
         style_dim,
         n_mlp,
         channel_multiplier=2,
+        narrow=1,
         blur_kernel=[1, 3, 3, 1],
         lr_mlp=0.01,
+        is_concat=True,
     ):
         super(GPEN, self).__init__()
         channels = {
-            4: 512,
-            8: 512,
-            16: 512,
-            32: 512,
-            64: 256 * channel_multiplier,
-            128: 128 * channel_multiplier,
-            256: 64 * channel_multiplier,
-            512: 32 * channel_multiplier,
-            1024: 16 * channel_multiplier,
+            4: int(512 * narrow),
+            8: int(512 * narrow),
+            16: int(512 * narrow),
+            32: int(512 * narrow),
+            64: int(256 * channel_multiplier * narrow),
+            128: int(128 * channel_multiplier * narrow),
+            256: int(64 * channel_multiplier * narrow),
+            512: int(32 * channel_multiplier * narrow),
+            1024: int(16 * channel_multiplier * narrow),
+            2048: int(8 * channel_multiplier * narrow)
         }
-
         self.log_size = int(math.log(size, 2))
-        self.generator = StyleGANv2Generator(size,
-                                             style_dim,
-                                             n_mlp,
-                                             channel_multiplier=channel_multiplier,
-                                             blur_kernel=blur_kernel,
-                                             lr_mlp=lr_mlp,
-                                             is_concat=True)
-
+        self.generator = StyleGANv2Generator(
+            size,
+            style_dim,
+            n_mlp,
+            channel_multiplier=channel_multiplier,
+            narrow=narrow,
+            blur_kernel=blur_kernel,
+            lr_mlp=lr_mlp,
+            is_concat=is_concat)
+
         conv = [ConvLayer(3, channels[size], 1)]
         self.ecd0 = nn.Sequential(*conv)
         in_channel = channels[size]
 
-        self.names = ['ecd%d'%i for i in range(self.log_size-1)]
+        self.names = ['ecd%d' % i for i in range(self.log_size - 1)]
         for i in range(self.log_size, 2, -1):
-            out_channel = channels[2 ** (i - 1)]
-            conv = [ConvLayer(in_channel, out_channel, 3, downsample=True)]
-            setattr(self, self.names[self.log_size-i+1], nn.Sequential(*conv))
+            out_channel = channels[2**(i - 1)]
+            conv = [ConvLayer(in_channel, out_channel, 3, downsample=True)]
+            setattr(self, self.names[self.log_size - i + 1],
+                    nn.Sequential(*conv))
             in_channel = out_channel
 
-        self.final_linear = nn.Sequential(EqualLinear(channels[4] * 4 * 4, style_dim, activation='fused_lrelu'))
+        self.final_linear = nn.Sequential(
+            EqualLinear(channels[4] * 4 * 4,
+                        style_dim,
+                        activation='fused_lrelu'))
 
-    def forward(self,
+    def forward(
+        self,
         inputs,
         return_latents=False,
         inject_index=None,
@@ -74,15 +87,20 @@ class GPEN(nn.Layer):
         input_is_latent=False,
     ):
         noise = []
-        for i in range(self.log_size-1):
+        for i in range(self.log_size - 1):
             ecd = getattr(self, self.names[i])
             inputs = ecd(inputs)
             noise.append(inputs)
 
         inputs = inputs.reshape([inputs.shape[0], -1])
         outs = self.final_linear(inputs)
-        outs = self.generator([outs], return_latents, inject_index, truncation,
-                              truncation_latent, input_is_latent,
-                              noise=noise[::-1])
+        noise = list(
+            itertools.chain.from_iterable(
+                itertools.repeat(x, 2) for x in noise))[::-1]
+        outs = self.generator([outs],
+                              return_latents,
+                              inject_index,
+                              truncation,
+                              truncation_latent,
+                              input_is_latent,
+                              noise=noise[1:])
         return outs
-
-
diff --git a/ppgan/models/gpen_model.py b/ppgan/models/gpen_model.py
new file mode 100644
index 
0000000..aa2f585 --- /dev/null +++ b/ppgan/models/gpen_model.py @@ -0,0 +1,199 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle + +from .base_model import BaseModel + +from .builder import MODELS +from .generators.builder import build_generator +from .discriminators.builder import build_discriminator +from ..modules.init import init_weights + +from .criterions.IDLoss.id_loss import IDLoss +from paddle.nn import functional as F +from paddle import autograd +import math + + +def d_logistic_loss(real_pred, fake_pred): + real_loss = F.softplus(-real_pred) + fake_loss = F.softplus(fake_pred) + + return real_loss.mean() + fake_loss.mean() + + +def d_r1_loss(real_pred, real_img): + grad_real, = autograd.grad(outputs=real_pred.sum(), + inputs=real_img, + create_graph=True) + grad_penalty = grad_real.pow(2).reshape([grad_real.shape[0], + -1]).sum(1).mean() + + return grad_penalty + + +def g_nonsaturating_loss(fake_pred, + loss_funcs=None, + fake_img=None, + real_img=None, + input_img=None): + smooth_l1_loss, id_loss = loss_funcs + + loss = F.softplus(-fake_pred).mean() + loss_l1 = smooth_l1_loss(fake_img, real_img) + loss_id = id_loss(fake_img, real_img, input_img) + loss += 1.0 * loss_l1 + 1.0 * loss_id + + return loss + + +def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01): + noise = paddle.randn(fake_img.shape) / math.sqrt( + fake_img.shape[2] * fake_img.shape[3]) + grad, = autograd.grad(outputs=(fake_img * noise).sum(), + inputs=latents, + create_graph=True) + path_lengths = paddle.sqrt(grad.pow(2).sum(2).mean(1)) + + path_mean = mean_path_length + decay * (path_lengths.mean() - + mean_path_length) + + path_penalty = (path_lengths - path_mean).pow(2).mean() + + return path_penalty, path_mean.detach(), path_lengths + + +@MODELS.register() +class GPENModel(BaseModel): + """ This class implements the gpen model. 
+ + """ + + def __init__(self, generator, discriminator=None, direction='a2b'): + + super(GPENModel, self).__init__() + + self.direction = direction + # define networks (both generator and discriminator) + self.nets['netG'] = build_generator(generator) + self.nets['g_ema'] = build_generator(generator) + self.nets['g_ema'].eval() + + if discriminator: + self.nets['netD'] = build_discriminator(discriminator) + + self.accum = 0.5**(32 / (10 * 1000)) + self.mean_path_length = 0 + + self.gan_criterions = [] + self.gan_criterions.append(paddle.nn.SmoothL1Loss()) + self.gan_criterions.append(IDLoss()) + self.current_iter = 0 + + def setup_input(self, input): + + self.degraded_img = paddle.to_tensor(input[0]) + self.real_img = paddle.to_tensor(input[1]) + + def forward(self, test_mode=False, regularize=False): + if test_mode: + self.fake_img, _ = self.nets['g_ema'](self.degraded_img) # G(A) + else: + if regularize: + self.fake_img, self.latents = self.nets['netG']( + self.degraded_img, return_latents=True) + else: + self.fake_img, _ = self.nets['netG'](self.degraded_img) + + def backward_D(self, regularize=False): + """Calculate GAN loss for the discriminator""" + if regularize: + self.real_img.stop_gradient = False + real_pred = self.nets['netD'](self.real_img) + r1_loss = d_r1_loss(real_pred, self.real_img) + (10 / 2 * r1_loss * 16).backward() + else: + fake_pred = self.nets['netD'](self.fake_img) + real_pred = self.nets['netD'](self.real_img) + self.loss_D = d_logistic_loss(real_pred, fake_pred) + self.loss_D.backward() + self.losses['D_loss'] = self.loss_D + + def backward_G(self, regularize): + """Calculate GAN and L1 loss for the generator""" + + if regularize: + path_loss, self.mean_path_length, path_lengths = g_path_regularize( + self.fake_img, self.latents, self.mean_path_length) + weighted_path_loss = 2 * 4 * path_loss + weighted_path_loss.backward() + else: + fake_pred = self.nets['netD'](self.fake_img) + self.loss_G = g_nonsaturating_loss(fake_pred, self.gan_criterions, + self.fake_img, self.real_img, + self.degraded_img) + self.loss_G.backward() + self.losses['G_loss'] = self.loss_G + + def train_iter(self, optimizers=None): + + self.current_iter += 1 + # update D + self.set_requires_grad(self.nets['netD'], True) + self.set_requires_grad(self.nets['netG'], False) + self.forward(test_mode=False) + optimizers['optimD'].clear_grad() + self.backward_D(regularize=False) + optimizers['optimD'].step() + + d_regularize = self.current_iter % 24 == 0 + if d_regularize: + optimizers['optimD'].clear_grad() + self.backward_D(regularize=True) + optimizers['optimD'].step() + # update G + self.set_requires_grad(self.nets['netD'], False) + self.set_requires_grad(self.nets['netG'], True) + self.forward(test_mode=False) + optimizers['optimG'].clear_grad() + self.backward_G(regularize=False) + optimizers['optimG'].step() + + g_regularize = self.current_iter % 4 == 0 + if g_regularize: + self.forward(test_mode=False, regularize=True) + optimizers['optimG'].clear_grad() + self.backward_G(regularize=True) + optimizers['optimG'].step() + + self.accumulate(self.nets['g_ema'], self.nets['netG'], self.accum) + + def test_iter(self, metrics=None): + self.nets['g_ema'].eval() + self.forward(test_mode=True) + + with paddle.no_grad(): + if metrics is not None: + for metric in metrics.values(): + metric.update(self.fake_img, self.real_img) + + def accumulate(self, model1, model2, decay=0.999): + par1 = dict(model1.state_dict()) + par2 = dict(model2.state_dict()) + + for k in par1.keys(): + par1[k] = par1[k] * decay 
+ par2[k] * (1 - decay) + + model1.load_dict(par1) diff --git a/ppgan/modules/equalized.py b/ppgan/modules/equalized.py index 2ef60e6..9d14402 100644 --- a/ppgan/modules/equalized.py +++ b/ppgan/modules/equalized.py @@ -28,6 +28,7 @@ class EqualConv2D(nn.Layer): """This convolutional layer class stabilizes the learning rate changes of its parameters. Equalizing learning rate keeps the weights in the network at a similar scale during training. """ + def __init__(self, in_channel, out_channel, @@ -74,6 +75,7 @@ class EqualLinear(nn.Layer): """This linear layer class stabilizes the learning rate changes of its parameters. Equalizing learning rate keeps the weights in the network at a similar scale during training. """ + def __init__(self, in_dim, out_dim, @@ -115,3 +117,50 @@ class EqualLinear(nn.Layer): return ( f"{self.__class__.__name__}({self.weight.shape[0]}, {self.weight.shape[1]})" ) + + +class EqualLinear_gpen(nn.Layer): + """This linear layer class stabilizes the learning rate changes of its parameters. + Equalizing learning rate keeps the weights in the network at a similar scale during training. + """ + + def __init__(self, + in_dim, + out_dim, + bias=True, + bias_init=0, + lr_mul=1, + activation=None): + super().__init__() + + self.weight = self.create_parameter( + (out_dim, in_dim), default_initializer=nn.initializer.Normal()) + self.weight.set_value((self.weight / lr_mul)) + + if bias: + self.bias = self.create_parameter( + (out_dim, ), nn.initializer.Constant(bias_init)) + + else: + self.bias = None + + self.activation = activation + + self.scale = (1 / math.sqrt(in_dim)) * lr_mul + self.lr_mul = lr_mul + + def forward(self, input): + if self.activation: + out = F.linear(input, (self.weight * self.scale).t()) + out = fused_leaky_relu(out, self.bias * self.lr_mul) + + else: + out = F.linear(input, (self.weight * self.scale).t(), + bias=self.bias * self.lr_mul) + + return out + + def __repr__(self): + return ( + f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})' + ) diff --git a/test_tipc/configs/GPEN/train_infer_python.txt b/test_tipc/configs/GPEN/train_infer_python.txt new file mode 100644 index 0000000..2deb575 --- /dev/null +++ b/test_tipc/configs/GPEN/train_infer_python.txt @@ -0,0 +1,51 @@ +===========================train_params=========================== +model_name:GPEN +python:python3.7 +gpu_list:0 +## +auto_cast:null +total_iters:lite_train_lite_infer=10 +output_dir:./output/ +snapshot_config.interval:lite_train_lite_infer=10 +pretrained_model:null +train_model_name:gpen*/*checkpoint.pdparams +train_infer_img_dir:null +null:null +## +trainer:norm_train +norm_train:tools/main.py -c configs/gpen_256_ffhq.yaml --seed 100 -o log_config.interval=1 +pact_train:null +fpgm_train:null +distill_train:null +null:null +null:null +## +===========================eval_params=========================== +eval:null +null:null +## +===========================infer_params=========================== +--output_dir:./output/ +load:null +norm_export:tools/export_model.py -c configs/gpen_256_ffhq.yaml --inputs_size=1,3,256,256 --model_name inference --load +quant_export:null +fpgm_export:null +distill_export:null +export1:null +export2:null +inference_dir:inference +train_model:./inference/gpen/gpenmodel_g_ema +infer_export:null +infer_quant:False +inference:tools/inference.py --model_type GPEN --seed 100 -c configs/gpen_256_ffhq.yaml --output_path test_tipc/output/ -o dataset.test.amount=5 +--device:gpu +null:null +null:null +null:null +null:null +null:null 
+--model_path: +null:null +null:null +--benchmark:True +null:null diff --git a/test_tipc/prepare.sh b/test_tipc/prepare.sh index 1af5e4c..edd4934 100644 --- a/test_tipc/prepare.sh +++ b/test_tipc/prepare.sh @@ -54,6 +54,10 @@ if [ ${MODE} = "lite_train_lite_infer" ];then rm -rf ./data/ffhq* wget -nc -P ./data/ https://paddlegan.bj.bcebos.com/datasets/ffhq.tar --no-check-certificate cd ./data/ && tar xf ffhq.tar && cd ../ ;; + GPEN) + rm -rf ./data/ffhq* + wget -nc -P ./data/ https://paddlegan.bj.bcebos.com/datasets/ffhq.tar --no-check-certificate + cd ./data/ && tar xf ffhq.tar && cd ../ ;; FOMM) rm -rf ./data/fom_lite* wget -nc -P ./data/ https://paddlegan.bj.bcebos.com/datasets/fom_lite.tar --no-check-certificate --no-check-certificate @@ -106,6 +110,10 @@ elif [ ${MODE} = "lite_train_whole_infer" ];then rm -rf ./data/ffhq* wget -nc -P ./data/ https://paddlegan.bj.bcebos.com/datasets/ffhq.tar --no-check-certificate cd ./data/ && tar xf ffhq.tar && cd ../ + elif [ ${model_name} == "GPEN" ]; then + rm -rf ./data/ffhq* + wget -nc -P ./data/ https://paddlegan.bj.bcebos.com/datasets/ffhq.tar --no-check-certificate + cd ./data/ && tar xf ffhq.tar && cd ../ elif [ ${model_name} == "basicvsr" ]; then rm -rf ./data/reds* wget -nc -P ./data/ https://paddlegan.bj.bcebos.com/datasets/reds_lite.tar --no-check-certificate diff --git a/tools/export_model.py b/tools/export_model.py index ee2e563..42c5425 100644 --- a/tools/export_model.py +++ b/tools/export_model.py @@ -16,6 +16,8 @@ import os import sys import argparse +sys.path.append(".") + import ppgan from ppgan.utils.config import get_config from ppgan.utils.setup import setup @@ -76,7 +78,7 @@ def main(args, cfg): for net_name, net in model.nets.items(): if net_name in state_dicts: net.set_state_dict(state_dicts[net_name]) - model.export_model(cfg.export_model, args.output_dir, inputs_size, + model.export_model(cfg.export_model, args.output_dir, inputs_size, args.export_serving_model, args.model_name) diff --git a/tools/inference.py b/tools/inference.py index 6ee2108..a834742 100644 --- a/tools/inference.py +++ b/tools/inference.py @@ -4,6 +4,10 @@ import numpy as np import random import os from collections import OrderedDict +import sys +import cv2 + +sys.path.append(".") from ppgan.utils.config import get_config from ppgan.datasets.builder import build_dataloader @@ -15,7 +19,7 @@ from ppgan.metrics import build_metric MODEL_CLASSES = ["pix2pix", "cyclegan", "wav2lip", "esrgan", \ - "edvr", "fom", "stylegan2", "basicvsr", "msvsr", "singan","prenet"] + "edvr", "fom", "stylegan2", "basicvsr", "msvsr", "singan","prenet","GPEN"] def parse_args(): @@ -313,7 +317,7 @@ def main(): metric_file = os.path.join(args.output_path, "singan/metric.txt") for metric in metrics.values(): metric.update(prediction, data['A']) - elif model_type == "prenet": + elif model_type == "prenet": lq = data['lq'].numpy() gt = data['gt'].numpy() input_handles[0].copy_from_cpu(lq) @@ -329,23 +333,31 @@ def main(): metric_file = os.path.join(args.output_path, "prenet/metric.txt") for metric in metrics.values(): metric.update(image_numpy, gt_img) - - elif model_type == "prenet": - lq = data['lq'].numpy() - gt = data['gt'].numpy() + elif model_type == "GPEN": + lq = data[0].numpy() input_handles[0].copy_from_cpu(lq) predictor.run() prediction = output_handle.copy_to_cpu() - prediction = paddle.to_tensor(prediction) - gt = paddle.to_tensor(gt) - image_numpy = tensor2img(prediction, min_max) - gt_img = tensor2img(gt, min_max) - save_image( - image_numpy, - 
os.path.join(args.output_path, "prenet/{}.png".format(i))) - metric_file = os.path.join(args.output_path, "prenet/metric.txt") + target = data[1].numpy() + + metric_file = os.path.join(args.output_path, model_type, + "metric.txt") for metric in metrics.values(): - metric.update(image_numpy, gt_img) + metric.update(prediction, target) + + lq = paddle.to_tensor(lq) + target = paddle.to_tensor(target) + prediction = paddle.to_tensor(prediction) + + lq = lq.transpose([0, 2, 3, 1]) + target = target.transpose([0, 2, 3, 1]) + prediction = prediction.transpose([0, 2, 3, 1]) + sample_result = paddle.concat((lq[0], prediction[0], target[0]), 1) + sample = cv2.cvtColor((sample_result.numpy() + 1) / 2 * 255, + cv2.COLOR_RGB2BGR) + file_name = os.path.join(args.output_path, model_type, + "{}.png".format(i)) + cv2.imwrite(file_name, sample) if metrics: log_file = open(metric_file, 'a') -- GitLab
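
---

## Notes: how the pieces in this patch fit together

The short sketches below restate, in runnable form, the key mechanisms the patch introduces. They are illustrative only and are not part of the diff.

`ModulatedConv2D` in `generator_gpen.py` folds the per-sample style vector into the convolution weight (modulation), rescales each output filter back to unit norm (demodulation), and then runs a grouped convolution with one group per sample. A minimal sketch of the modulate/demodulate step, with the equalized-LR `scale` factor omitted for brevity (`demodulate_weight` is a hypothetical helper name, not in the patch):

```python
import paddle


def demodulate_weight(weight, style, eps=1e-8):
    # weight: (1, out_c, in_c, k, k); style: (batch, in_c)
    batch, in_c = style.shape
    style = style.reshape((batch, 1, in_c, 1, 1))
    w = weight * style                                    # modulate
    demod = paddle.rsqrt((w * w).sum([2, 3, 4]) + eps)    # (batch, out_c)
    return w * demod.reshape((batch, w.shape[1], 1, 1, 1))  # demodulate


w = paddle.randn((1, 8, 4, 3, 3))
s = paddle.randn((2, 4))
print(demodulate_weight(w, s).shape)  # [2, 8, 4, 3, 3]
```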
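With `is_concat=True` (the GPEN setting), `NoiseInjection` concatenates the weighted injected tensor along the channel axis instead of adding it, which is why downstream modules are sized with `out_channel * 2 if is_concat else out_channel`. A toy shape check, assuming the injected tensor matches the image's channel count, as the encoder features supplied by `GPEN.forward` do:

```python
import paddle

feat = paddle.randn((1, 16, 8, 8))  # StyledConv conv output
enc = paddle.randn((1, 16, 8, 8))   # encoder feature at this resolution
w = paddle.zeros((1, ))             # NoiseInjection weight, initialized to 0

out_add = feat + w * enc                     # is_concat=False: 16 channels
out_cat = paddle.concat([feat, w * enc], 1)  # is_concat=True:  32 channels
print(out_add.shape, out_cat.shape)          # [1, 16, 8, 8] [1, 32, 8, 8]
```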
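The `itertools` expression in `GPEN.forward` duplicates each encoder feature and reverses the list, so the features run coarse-to-fine and the two `StyledConv`s at each resolution share that resolution's feature; `noise[1:]` then drops one copy of the deepest feature, which `conv1` consumes alone. (The fallback branch in `StyleGANv2Generator.forward` for `noise is None` reads `self.n_mlp`, which `__init__` does not appear to store; GPEN always supplies encoder features, so that branch is not exercised in this pipeline.) The same transform on placeholder strings:

```python
import itertools

# encoder outputs for size=256, fine-to-coarse
feats = ["f256", "f128", "f64", "f32", "f16", "f8", "f4"]
noise = list(
    itertools.chain.from_iterable(itertools.repeat(x, 2)
                                  for x in feats))[::-1]
print(noise[1:])
# ['f4', 'f8', 'f8', 'f16', 'f16', 'f32', 'f32',
#  'f64', 'f64', 'f128', 'f128', 'f256', 'f256']
# index 0 feeds conv1; each later pair feeds the two convs at one resolution
```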
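`GPENModel.train_iter` applies the R1 and path-length penalties lazily: only every k-th iteration, scaled up so the time-averaged strength roughly matches applying them each step (the `10 / 2 * r1_loss * 16` factor in `backward_D` and `2 * 4 * path_loss` in `backward_G` follow StyleGAN2's weight-times-interval convention). A standalone sketch of the R1 term from `d_r1_loss`, using a toy quadratic stand-in for the discriminator:

```python
import paddle


def r1_penalty(real_pred, real_img):
    # gradient penalty on real samples only, as in d_r1_loss
    grad, = paddle.grad(outputs=real_pred.sum(),
                        inputs=real_img,
                        create_graph=True)
    return grad.pow(2).reshape([grad.shape[0], -1]).sum(1).mean()


x = paddle.randn((2, 3, 8, 8))
x.stop_gradient = False
pred = (x * x).sum(axis=[1, 2, 3])        # stand-in for netD(real)
loss = 10 / 2 * r1_penalty(pred, x) * 16  # weight * penalty * interval
loss.backward()
print(float(loss))
```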
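`GPENModel` keeps `g_ema` as an exponential moving average of `netG` via `accumulate`. The decay `0.5 ** (32 / (10 * 1000))` is the StyleGAN2 convention: a half-life of 10k training images at a reference batch size of 32, independent of the actual batch size in the config:

```python
decay = 0.5**(32 / (10 * 1000))  # self.accum in GPENModel
print(round(decay, 6))           # 0.997784
# after 10k images (10000 / 32 iterations) the old weights' share halves:
print(round(decay**(10 * 1000 / 32), 6))  # 0.5
```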
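`EqualLinear_gpen` in `equalized.py` implements the equalized learning rate: the parameter is stored divided by `lr_mul`, and every forward pass rescales it by `scale = (1 / sqrt(in_dim)) * lr_mul`, so the effective weight always has He-style statistics while gradient updates on the stored weight are effectively damped by `lr_mul` (0.01 for the style MLP). A quick check that the output statistics do not depend on `lr_mul`:

```python
import math

import paddle
import paddle.nn.functional as F


def equal_linear(x, lr_mul):
    in_dim, out_dim = x.shape[1], 256
    weight = paddle.randn((out_dim, in_dim)) / lr_mul  # stored weight
    scale = (1 / math.sqrt(in_dim)) * lr_mul           # per-forward rescale
    return F.linear(x, (weight * scale).t())           # as in the layer


x = paddle.randn((1024, 512))
print(float(equal_linear(x, 1.0).std()),
      float(equal_linear(x, 0.01).std()))  # both around 1.0
```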
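Finally, the GPEN branch added to `tools/inference.py` maps the model's `[-1, 1]` output back to an image with `(x + 1) / 2 * 255` and an RGB-to-BGR swap before `cv2.imwrite`. A self-contained version of that mapping; the explicit clip and uint8 cast are an addition of mine (the diff hands `imwrite` a float array), not something the patch does:

```python
import cv2
import numpy as np


def chw_to_bgr_uint8(chw):
    """Map a [-1, 1] float CHW array to an HWC uint8 BGR image."""
    hwc = (chw.transpose(1, 2, 0) + 1) / 2 * 255
    hwc = np.clip(hwc, 0, 255).astype("uint8")
    return cv2.cvtColor(hwc, cv2.COLOR_RGB2BGR)


sample = chw_to_bgr_uint8(np.random.uniform(-1, 1, (3, 256, 256)))
cv2.imwrite("sample.png", sample)
```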