From 45922b0c534dc44a4969aa5ab409647552cc4729 Mon Sep 17 00:00:00 2001 From: yangshurong <73787862+yangshurong@users.noreply.github.com> Date: Fri, 14 Oct 2022 10:15:50 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90=E8=AE=BA=E6=96=87=E5=A4=8D=E7=8E=B0?= =?UTF-8?q?=E3=80=91GFPGAN=20(#703)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * gfpgan push * gfpgan finish * gfpgan add init * gfpgan del recover * 11111 * gfpgan change name * gfpgan recover name * Update GFPGAN.md Co-authored-by: wangna11BD <79366697+wangna11BD@users.noreply.github.com> --- configs/gfpgan_ffhq1024.yaml | 205 +++ docs/en_US/tutorials/GFPGAN.md | 207 +++ docs/zh_CN/tutorials/GFPGAN.md | 198 +++ ppgan/datasets/__init__.py | 2 + ppgan/datasets/gfpgan_datasets.py | 202 +++ ppgan/datasets/paired_image_datasets.py | 135 ++ .../face_enhancement/gfpgan_enhance.py | 87 + ppgan/models/__init__.py | 1 + ppgan/models/criterions/__init__.py | 1 + ppgan/models/criterions/gfpgan_loss.py | 427 +++++ ppgan/models/discriminators/__init__.py | 1 + .../discriminators/arcface_arch_paddle.py | 285 ++++ ppgan/models/generators/__init__.py | 2 + ppgan/models/generators/gfpganv1_arch.py | 1418 +++++++++++++++++ .../models/generators/gfpganv1_clean_arch.py | 329 ++++ .../models/generators/stylegan2_clean_arch.py | 396 +++++ ppgan/models/gfpgan_model.py | 552 +++++++ ppgan/utils/gfpgan_tools.py | 1127 +++++++++++++ .../configs/GFPGAN/train_infer_python.txt | 59 + test_tipc/prepare.sh | 5 + tools/inference.py | 9 + 21 files changed, 5648 insertions(+) create mode 100644 configs/gfpgan_ffhq1024.yaml create mode 100644 docs/en_US/tutorials/GFPGAN.md create mode 100644 docs/zh_CN/tutorials/GFPGAN.md mode change 100755 => 100644 ppgan/datasets/__init__.py create mode 100644 ppgan/datasets/gfpgan_datasets.py create mode 100644 ppgan/datasets/paired_image_datasets.py create mode 100644 ppgan/faceutils/face_enhancement/gfpgan_enhance.py create mode 100644 ppgan/models/criterions/gfpgan_loss.py create mode 100644 ppgan/models/discriminators/arcface_arch_paddle.py mode change 100755 => 100644 ppgan/models/generators/__init__.py create mode 100644 ppgan/models/generators/gfpganv1_arch.py create mode 100644 ppgan/models/generators/gfpganv1_clean_arch.py create mode 100644 ppgan/models/generators/stylegan2_clean_arch.py create mode 100644 ppgan/models/gfpgan_model.py create mode 100644 ppgan/utils/gfpgan_tools.py create mode 100644 test_tipc/configs/GFPGAN/train_infer_python.txt diff --git a/configs/gfpgan_ffhq1024.yaml b/configs/gfpgan_ffhq1024.yaml new file mode 100644 index 0000000..8104286 --- /dev/null +++ b/configs/gfpgan_ffhq1024.yaml @@ -0,0 +1,205 @@ +total_iters: 800000 +output_dir: output +find_unused_parameters: True + +log_config: + interval: 100 + visiual_interval: 100 + +snapshot_config: + interval: 30000 + +enable_visualdl: False + +validate: + interval: 5000 + save_img: True + + metrics: + psnr: + name: PSNR + crop_border: 0 + test_y_channel: false + fid: + name: FID + batch_size: 8 +model: + name: GFPGANModel + network_g: + name: GFPGANv1 + out_size: 512 + num_style_feat: 512 + channel_multiplier: 1 + resample_kernel: [1, 3, 3, 1] + decoder_load_path: https://paddlegan.bj.bcebos.com/models/StyleGAN2_FFHQ512_Cmul1.pdparams + fix_decoder: true + num_mlp: 8 + lr_mlp: 0.01 + input_is_latent: true + different_w: true + narrow: 1 + sft_half: true + network_d: + name: StyleGAN2DiscriminatorGFPGAN + out_size: 512 + channel_multiplier: 1 + resample_kernel: [1, 3, 3, 1] + network_d_left_eye: + type: 
FacialComponentDiscriminator + + network_d_right_eye: + type: FacialComponentDiscriminator + + network_d_mouth: + type: FacialComponentDiscriminator + + network_identity: + name: ResNetArcFace + block: IRBlock + layers: [2, 2, 2, 2] + use_se: False + + path: + image_visual: gfpgan_train_outdir + pretrain_network_g: ~ + param_key_g: params_ema + strict_load_g: ~ + pretrain_network_d: ~ + pretrain_network_d_left_eye: https://paddlegan.bj.bcebos.com/models/Facial_component_discriminator.pdparams + pretrain_network_d_right_eye: https://paddlegan.bj.bcebos.com/models/Facial_component_discriminator.pdparams + pretrain_network_d_mouth: https://paddlegan.bj.bcebos.com/models/Facial_component_discriminator.pdparams + pretrain_network_identity: https://paddlegan.bj.bcebos.com/models/arcface_resnet18.pdparams + + + # losses + # pixel loss + pixel_opt: + name: GFPGANL1Loss + loss_weight: !!float 1e-1 + reduction: mean + # L1 loss used in pyramid loss, component style loss and identity loss + L1_opt: + name: GFPGANL1Loss + loss_weight: 1 + reduction: mean + + # image pyramid loss + pyramid_loss_weight: 1 + remove_pyramid_loss: 50000 + # perceptual loss (content and style losses) + perceptual_opt: + name: GFPGANPerceptualLoss + layer_weights: + # before relu + "conv1_2": 0.1 + "conv2_2": 0.1 + "conv3_4": 1 + "conv4_4": 1 + "conv5_4": 1 + vgg_type: vgg19 + use_input_norm: true + perceptual_weight: !!float 1 + style_weight: 50 + range_norm: true + criterion: l1 + # gan loss + gan_opt: + name: GFPGANGANLoss + gan_type: wgan_softplus + loss_weight: !!float 1e-1 + # r1 regularization for discriminator + r1_reg_weight: 10 + # facial component loss + gan_component_opt: + name: GFPGANGANLoss + gan_type: vanilla + real_label_val: 1.0 + fake_label_val: 0.0 + loss_weight: !!float 1 + comp_style_weight: 200 + # identity loss + identity_weight: 10 + + net_d_iters: 1 + net_d_init_iters: 0 + net_d_reg_every: 16 + +export_model: + - { name: "net_g_ema", inputs_num: 1 } + +dataset: + train: + name: FFHQDegradationDataset + dataroot_gt: data/gfpgan_data/train + io_backend: + type: disk + + use_hflip: true + mean: [0.5, 0.5, 0.5] + std: [0.5, 0.5, 0.5] + out_size: 512 + + blur_kernel_size: 41 + kernel_list: ["iso", "aniso"] + kernel_prob: [0.5, 0.5] + blur_sigma: [0.1, 10] + downsample_range: [0.8, 8] + noise_range: [0, 20] + jpeg_range: [60, 100] + + # color jitter and gray + color_jitter_prob: 0.3 + color_jitter_shift: 20 + color_jitter_pt_prob: 0.3 + gray_prob: 0.01 + + # If you do not want colorization, please set + # color_jitter_prob: ~ + # color_jitter_pt_prob: ~ + # gray_prob: 0.01 + # gt_gray: True + + crop_components: true + component_path: https://paddlegan.bj.bcebos.com/models/FFHQ_eye_mouth_landmarks_512.pdparams + eye_enlarge_ratio: 1.4 + + # data loader + use_shuffle: true + num_workers: 4 + batch_size: 1 + prefetch_mode: ~ + + test: + # Please modify accordingly to use your own validation + # Or comment the val block if do not need validation during training + name: PairedImageDataset + dataroot_lq: data/gfpgan_data/lq + dataroot_gt: data/gfpgan_data/gt + io_backend: + type: disk + mean: [0.5, 0.5, 0.5] + std: [0.5, 0.5, 0.5] + scale: 1 + num_workers: 4 + batch_size: 8 + phase: val + +lr_scheduler: + name: MultiStepDecay + learning_rate: 0.002 + milestones: [600000, 700000] + gamma: 0.5 + +optimizer: + optim_g: + name: Adam + beta1: 0 + beta2: 0.99 + optim_d: + name: Adam + beta1: 0 + beta2: 0.99 + optim_component: + name: Adam + beta1: 0.9 + beta2: 0.99 diff --git a/docs/en_US/tutorials/GFPGAN.md 
b/docs/en_US/tutorials/GFPGAN.md
new file mode 100644
index 0000000..d4ca57b
--- /dev/null
+++ b/docs/en_US/tutorials/GFPGAN.md
@@ -0,0 +1,207 @@
+## GFPGAN Blind Face Restoration Model
+
+## 1、Introduction
+
+GFP-GAN leverages the rich and diverse priors encapsulated in a pretrained face GAN for blind face restoration.
+### Overview of the GFP-GAN framework:
+
+![image](https://user-images.githubusercontent.com/73787862/191736718-72f5aa09-d7a9-490b-b1f8-b609208d4654.png)
+
+GFP-GAN is comprised of a degradation removal module (U-Net) and a pretrained face GAN (such as StyleGAN2) used as a prior. The two are bridged by a latent code mapping and several Channel-Split Spatial Feature Transform (CS-SFT) layers.
+
+By modulating features in this way, it achieves realistic results while preserving high fidelity.
+
+For a more detailed introduction to the model, and for the reference repo, see the following AI Studio project:
+[https://aistudio.baidu.com/aistudio/projectdetail/4421649](https://aistudio.baidu.com/aistudio/projectdetail/4421649)
+
+In this experiment, we train our model with the Adam optimizer for a total of 210k iterations.
+
+The restoration results of our GFPGAN experiments are as follows:
+
+Model | LPIPS | FID | PSNR
+--- |:---:|:---:|:---:|
+GFPGAN | 0.3817 | 36.8068 | 65.0461
+
+## 2、Preparation
+
+### 2.1 Dataset Preparation
+
+The GFPGAN training set is the classic FFHQ face dataset, with a total of 70,000 high-resolution 1024 x 1024 face images. The test set is the CELEBA-HQ dataset, with a total of 2,000 high-resolution face images; its low-quality inputs are generated in the same way as during training.
+For details, please refer to the dataset URLs: [FFHQ](https://github.com/NVlabs/ffhq-dataset), [CELEBA-HQ](https://github.com/tkarras/progressive_growing_of_gans).
+The specific download links are given below:
+
+**Original dataset download address:**
+
+**FFHQ :** https://drive.google.com/drive/folders/1tZUcXDBeOibC6jcMCtgRRz67pzrAHeHL?usp=drive_open
+
+**CELEBA-HQ:** https://drive.google.com/drive/folders/0B4qLcYyJmiz0TXY1NG02bzZVRGs?resourcekey=0-arAVTUfW9KRhN-irJchVKQ&usp=sharing
+
+The data is organized as follows:
+
+```
+|-- data/GFPGAN
+    |-- train
+        |-- 00000.png
+        |-- 00001.png
+        |-- ......
+        |-- 00999.png
+        |-- ......
+        |-- 69999.png
+    |-- lq
+        |-- 2000 jpg images
+    |-- gt
+        |-- 2000 jpg images
+```
+
+Please modify the dataroot parameters of the train and test datasets in the configs/gfpgan_ffhq1024.yaml configuration file to point to your training set and test set paths.
+
+### 2.2 Model preparation
+
+**Model parameter file and training log download address:**
+
+https://paddlegan.bj.bcebos.com/models/GFPGAN.pdparams
+
+Download the model parameters and test images from the link above and put them in the data/ folder in the project root directory. The specific file structure is as follows:
+
+The params file is a Python dict and can be loaded with PaddlePaddle. It contains the keys net_g and net_g_ema; either of them can be used for inference.
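+
+If you want to check the downloaded checkpoint before using it, a minimal sketch (assuming it was saved to data/GFPGAN.pdparams as described above) is:
+
+```python
+import paddle
+
+# Load the checkpoint and list its top-level keys.
+state = paddle.load('data/GFPGAN.pdparams')
+print(state.keys())             # should include 'net_g' and 'net_g_ema'
+print(len(state['net_g_ema']))  # number of parameter tensors in the EMA generator
+```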
+
+## 3、Getting started
+
+### 3.1 Model training
+
+Enter the following command in the console to start training:
+
+ ```bash
+ python tools/main.py -c configs/gfpgan_ffhq1024.yaml
+ ```
+
+The model supports both single-card and multi-card training. For multi-card training you can use the following commands:
+
+ ```bash
+!CUDA_VISIBLE_DEVICES=0,1,2,3
+!python -m paddle.distributed.launch tools/main.py \
+    --config-file configs/gfpgan_ffhq1024.yaml
+ ```
+
+Model training requires paddle 2.3 or above, and it also has to wait for paddle to implement the second-order gradient of the elementwise_pow operator. Paddle 2.2.2 can run the code, but the model cannot be trained successfully because some loss functions compute wrong gradients there. If an error is reported during training, training is not supported for the time being; you can skip the training part and directly use the provided model parameters for testing. Model evaluation and testing work with paddle 2.2.2 and above.
+
+### 3.2 Model evaluation
+
+To evaluate the model, enter the following command in the console, using the downloaded model parameters mentioned above:
+
+ ```shell
+python tools/main.py -c configs/gfpgan_ffhq1024.yaml --load GFPGAN.pdparams --evaluate-only
+ ```
+
+To test with your own trained model, modify the path after --load.
+
+### 3.3 Model prediction
+
+#### 3.3.1 Export the model
+
+After training, you need to use ``tools/export_model.py`` to extract the weights of the generator from the trained model (the exported model includes the generator only).
+Enter the following command to export the generator:
+
+```bash
+python -u tools/export_model.py --config-file configs/gfpgan_ffhq1024.yaml \
+    --load GFPGAN.pdparams \
+    --inputs_size 1,3,512,512
+```
+
+#### 3.3.2 Process a single image
+
+You can use our tools in ppgan/faceutils/face_enhancement/gfpgan_enhance.py to quickly run inference on a single picture:
+```python
+%env PYTHONPATH=.:$PYTHONPATH
+%env CUDA_VISIBLE_DEVICES=0
+import paddle
+import cv2
+import numpy as np
+import sys
+from ppgan.faceutils.face_enhancement.gfpgan_enhance import gfp_FaceEnhancement
+# you can use your own image path
+img_path='test/2.png'
+img = cv2.imread(img_path, cv2.IMREAD_COLOR)
+# save the original (low-quality) picture
+cv2.imwrite('test/outlq.png',img)
+img=np.array(img).astype('float32')
+faceenhancer = gfp_FaceEnhancement()
+img = faceenhancer.enhance_from_image(img)
+# save the restored result
+cv2.imwrite('test/out_gfpgan.png',img)
+```
+
+![image](https://user-images.githubusercontent.com/73787862/191741112-b813a02c-6b19-4591-b80d-0bf5ce8ad07e.png)
+![image](https://user-images.githubusercontent.com/73787862/191741242-1f365048-ba25-450f-8abc-76e74d8786f8.png)
+
+## 4. TIPC
+
+### 4.1 Export the inference model
+
+```bash
+python -u tools/export_model.py --config-file configs/gfpgan_ffhq1024.yaml \
+    --load GFPGAN.pdparams \
+    --inputs_size 1,3,512,512
+```
+
+You can also change the parameter after --load to the model parameter file you want to test.
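+
+Before running the full inference script in the next section, you can sanity-check the exported model directly with the Paddle Inference Python API. This is only a sketch: the .pdmodel/.pdiparams paths below are assumptions, so point them at the files actually written by tools/export_model.py.
+
+```python
+import numpy as np
+from paddle.inference import Config, create_predictor
+
+# Assumed export paths; adjust to the real output of tools/export_model.py.
+config = Config('inference_model/gfpgan.pdmodel', 'inference_model/gfpgan.pdiparams')
+predictor = create_predictor(config)
+
+# Feed a dummy 1x3x512x512 input; real inputs should be RGB images normalized
+# with mean/std 0.5 (values in [-1, 1]), matching the training config.
+input_handle = predictor.get_input_handle(predictor.get_input_names()[0])
+input_handle.copy_from_cpu(np.random.uniform(-1, 1, (1, 3, 512, 512)).astype('float32'))
+predictor.run()
+
+output_handle = predictor.get_output_handle(predictor.get_output_names()[0])
+restored = output_handle.copy_to_cpu()
+print(restored.shape)
+```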
+ + + +### 4.2 Inference with a prediction engine + +```bash +%cd /home/aistudio/work/PaddleGAN +# %env PYTHONPATH=.:$PYTHONPATH +# %env CUDA_VISIBLE_DEVICES=0 +!python -u tools/inference.py --config-file configs/gfpgan_ffhq1024.yaml \ + --model_path GFPGAN.pdparams \ + --model_type gfpgan \ + --device gpu \ + -o validate=None +``` + + +### 4.3 Call the script to complete the training and push test in two steps + +To invoke the `lite_train_lite_infer` mode of the foot test base training prediction function, run: + +```bash +%cd /home/aistudio/work/PaddleGAN +!bash test_tipc/prepare.sh \ + test_tipc/configs/GFPGAN/train_infer_python.txt \ + lite_train_lite_infer +!bash test_tipc/test_train_inference_python.sh \ + test_tipc/configs/GFPGAN/train_infer_python.txt \ + lite_train_lite_infer +``` + + + +## 5、References + +``` +@InProceedings{wang2021gfpgan, + author = {Xintao Wang and Yu Li and Honglun Zhang and Ying Shan}, + title = {Towards Real-World Blind Face Restoration with Generative Facial Prior}, + booktitle={The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, + year = {2021} +} +``` diff --git a/docs/zh_CN/tutorials/GFPGAN.md b/docs/zh_CN/tutorials/GFPGAN.md new file mode 100644 index 0000000..1da5317 --- /dev/null +++ b/docs/zh_CN/tutorials/GFPGAN.md @@ -0,0 +1,198 @@ +## GFPGAN 盲脸复原模型 + + +## 1、介绍 +GFP-GAN利用丰富和多样化的先验封装在预先训练的面部GAN用于盲人面部恢复。 +### GFPGAN的整体结构: + +![image](https://user-images.githubusercontent.com/73787862/191736718-72f5aa09-d7a9-490b-b1f8-b609208d4654.png) + +GFP-GAN由降解去除物组成 +模块(U-Net)和预先训练的面部GAN(如StyleGAN2)作为先验。他们之间有隐藏的密码 +映射和几个通道分割空间特征变换(CS-SFT)层。 + +通过处理特征,它在保持高保真度的同时实现了真实的结果。 + +要了解更详细的模型介绍,并参考回购,您可以查看以下AI Studio项目 +[基于PaddleGAN复现GFPGAN](https://aistudio.baidu.com/aistudio/projectdetail/4421649) + +在这个实验中,我们训练 +我们的模型和Adam优化器共进行了210k次迭代。 + +GFPGAN的回收实验结果如下: + + +Model | LPIPS | FID | PSNR +--- |:---:|:---:|:---:| +GFPGAN | 0.3817 | 36.8068 | 65.0461 + +## 2、准备工作 + +### 2.1 数据集准备 + +GFPGAN模型训练集是经典的FFHQ人脸数据集, +总共有7万张高分辨率1024 x 1024的人脸图片, +测试集为CELEBA-HQ数据集,共有2000张高分辨率人脸图片。生成方式与训练时相同。 +For details, please refer to **Dataset URL:** [FFHQ](https://github.com/NVlabs/ffhq-dataset), [CELEBA-HQ](https://github.com/tkarras/progressive_growing_of_gans). +The specific download links are given below: + +**原始数据集地址:** + +**FFHQ :** https://drive.google.com/drive/folders/1tZUcXDBeOibC6jcMCtgRRz67pzrAHeHL?usp=drive_open + +**CELEBA-HQ:** https://drive.google.com/drive/folders/0B4qLcYyJmiz0TXY1NG02bzZVRGs?resourcekey=0-arAVTUfW9KRhN-irJchVKQ&usp=sharing + +数据集结构如下 + +``` +|-- data/GFPGAN + |-- train + |-- 00000.png + |-- 00001.png + |-- ...... + |-- 00999.png + |-- ...... + |-- 69999.png + |-- lq + |-- 2000张jpg图片 + |-- gt + |-- 2000张jpg图片 +``` + +请在configs/gfpgan_ffhq1024. 
data中修改数据集train和test的dataroot参数。Yaml配置文件到您的训练集和测试集路径。 + +### 2.2 模型准备 +**模型参数文件和训练日志下载地址:** + +https://paddlegan.bj.bcebos.com/models/GFPGAN.pdparams + +从链接下载模型参数和测试图像,并将它们放在项目根目录中的data/文件夹中。具体文件结构如下: + +params是一个dict(python中的一种类型),可以通过paddlepaddle加载。它包含key (net_g,net_g_ema),您可以使用其中任何一个来进行推断 + +## 3、开始使用 +模型训练 + +在控制台中输入以下代码开始训练: + + ```bash + python tools/main.py -c configs/gfpgan_ffhq1024.yaml + ``` + +该模型支持单卡训练和多卡训练。 +也可以使用如下命令进行多卡训练 + +```bash +!CUDA_VISIBLE_DEVICES=0,1,2,3 +!python -m paddle.distributed.launch tools/main.py \ + --config-file configs/gpfgan_ffhq1024.yaml +``` + +模型训练需要使用paddle2.3及以上版本,等待paddle实现elementwise_pow的二阶算子相关函数。paddle2.2.2版本可以正常运行,但由于某些损失函数会计算出错误的梯度,无法成功训练模型。如果在培训过程中报错,则暂时不支持培训。您可以跳过训练部分,直接使用提供的模型参数进行测试。模型评估和测试可以使用paddle2.2.2及以上版本。 + +### 3.2 模型评估 + +当评估模型时,在控制台中输入以下代码,使用上面提到的下载的模型参数: + + ```shell +python tools/main.py -c configs/gfpgan_ffhq1024.yaml --load GFPGAN.pdparams --evaluate-only + ``` + +当评估模型时,在控制台中输入以下代码,使用下载的模型。如果您想在您自己提供的模型上进行测试,请修改之后的路径 --load . + + + +### 3.3 模型预测 + +#### 3.3.1 导出模型 + +在训练之后,您需要使用' ' tools/export_model.py ' '从训练的模型中提取生成器的权重(仅包括生成器) +输入以下命令提取生成器的模型: + +```bash +python -u tools/export_model.py --config-file configs/gfpgan_ffhq1024.yaml \ + --load GFPGAN.pdparams \ + --inputs_size 1,3,512,512 +``` + + +#### 3.3.2 加载一张图片 + +你可以使用我们在ppgan/faceutils/face_enhancement/gfpgan_enhance.py中的工具来快速推断一张图片 + +```python +%env PYTHONPATH=.:$PYTHONPATH +%env CUDA_VISIBLE_DEVICES=0 +import paddle +import cv2 +import numpy as np +import sys +from ppgan.faceutils.face_enhancement.gfpgan_enhance import gfp_FaceEnhancement +# 图片路径可以用自己的 +img_path='test/2.png' +img = cv2.imread(img_path, cv2.IMREAD_COLOR) +# 这是原来的模糊图片 +cv2.imwrite('test/outlq.png',img) +img=np.array(img).astype('float32') +faceenhancer = gfp_FaceEnhancement() +img = faceenhancer.enhance_from_image(img) +# 这是生成的清晰图片 +cv2.imwrite('test/out_gfpgan.png',img) +``` + +![image](https://user-images.githubusercontent.com/73787862/191741112-b813a02c-6b19-4591-b80d-0bf5ce8ad07e.png) +![image](https://user-images.githubusercontent.com/73787862/191741242-1f365048-ba25-450f-8abc-76e74d8786f8.png) + + + + +## 4. 
Tipc + +### 4.1 导出推理模型 + +```bash +python -u tools/export_model.py --config-file configs/gfpgan_ffhq1024.yaml \ + --load GFPGAN.pdparams \ + --inputs_size 1,3,512,512 +``` + +### 4.2 使用paddleInference推理 + +```bash +%cd /home/aistudio/work/PaddleGAN +# %env PYTHONPATH=.:$PYTHONPATH +# %env CUDA_VISIBLE_DEVICES=0 +!python -u tools/inference.py --config-file configs/gfpgan_ffhq1024.yaml \ + --model_path GFPGAN.pdparams \ + --model_type gfpgan \ + --device gpu \ + -o validate=None +``` + + +### 4.3 一键TIPC + +调用足部测试基础训练预测函数的' lite_train_lite_infer '模式,执行: + +```bash +%cd /home/aistudio/work/PaddleGAN +!bash test_tipc/prepare.sh \ + test_tipc/configs/GFPGAN/train_infer_python.txt \ + lite_train_lite_infer +!bash test_tipc/test_train_inference_python.sh \ + test_tipc/configs/GFPGAN/train_infer_python.txt \ + lite_train_lite_infer +``` + + + +## 5、References + +``` +@InProceedings{wang2021gfpgan, + author = {Xintao Wang and Yu Li and Honglun Zhang and Ying Shan}, + title = {Towards Real-World Blind Face Restoration with Generative Facial Prior}, + booktitle={The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, + year = {2021} +} +``` diff --git a/ppgan/datasets/__init__.py b/ppgan/datasets/__init__.py old mode 100755 new mode 100644 index 175bd58..3e5a487 --- a/ppgan/datasets/__init__.py +++ b/ppgan/datasets/__init__.py @@ -32,4 +32,6 @@ from .photopen_dataset import PhotoPenDataset from .empty_dataset import EmptyDataset from .gpen_dataset import GPENDataset from .swinir_dataset import SwinIRDataset +from .gfpgan_datasets import FFHQDegradationDataset +from .paired_image_datasets import PairedImageDataset from .invdn_dataset import InvDNDataset diff --git a/ppgan/datasets/gfpgan_datasets.py b/ppgan/datasets/gfpgan_datasets.py new file mode 100644 index 0000000..5f34ac1 --- /dev/null +++ b/ppgan/datasets/gfpgan_datasets.py @@ -0,0 +1,202 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import cv2 +import math +import numpy as np +import random +import os + +import paddle +import paddle.nn.functional as F +from paddle.vision.transforms.functional import normalize + +from .builder import DATASETS + +from ppgan.utils.download import get_path_from_url +from ppgan.utils.gfpgan_tools import * + + +@DATASETS.register() +class FFHQDegradationDataset(paddle.io.Dataset): + """FFHQ dataset for GFPGAN. + + It reads high resolution images, and then generate low-quality (LQ) images on-the-fly. + + Args: + opt (dict): Config for train datasets. It contains the following keys: + dataroot_gt (str): Data root path for gt. + io_backend (dict): IO backend type and other kwarg. + mean (list | tuple): Image mean. + std (list | tuple): Image std. + use_hflip (bool): Whether to horizontally flip. + Please see more options in the codes. 
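+            out_size (int): Output (GT) image size, e.g. 512.
+            blur_kernel_size, kernel_list, kernel_prob, blur_sigma,
+                downsample_range, noise_range, jpeg_range: Settings of the
+                on-the-fly degradation (blur, downsampling, noise, JPEG).
+            crop_components (bool): Whether to return eye/mouth locations
+                read from the pre-computed landmark file in component_path.
+            eye_enlarge_ratio (float): Ratio to enlarge the eye regions.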
+ """ + def __init__(self, **opt): + super(FFHQDegradationDataset, self).__init__() + self.opt = opt + self.file_client = None + self.io_backend_opt = opt['io_backend'] + self.gt_folder = opt['dataroot_gt'] + self.mean = opt['mean'] + self.std = opt['std'] + self.out_size = opt['out_size'] + self.crop_components = opt.get('crop_components', False) + self.eye_enlarge_ratio = opt.get('eye_enlarge_ratio', 1) + if self.crop_components: + self.components_list = get_path_from_url(opt.get('component_path')) + self.components_list = paddle.load(self.components_list) + # print(self.components_list) + self.paths = paths_from_folder(self.gt_folder) + self.blur_kernel_size = opt['blur_kernel_size'] + self.kernel_list = opt['kernel_list'] + self.kernel_prob = opt['kernel_prob'] + self.blur_sigma = opt['blur_sigma'] + self.downsample_range = opt['downsample_range'] + self.noise_range = opt['noise_range'] + self.jpeg_range = opt['jpeg_range'] + self.color_jitter_prob = opt.get('color_jitter_prob') + self.color_jitter_pt_prob = opt.get('color_jitter_pt_prob') + self.color_jitter_shift = opt.get('color_jitter_shift', 20) + self.gray_prob = opt.get('gray_prob') + self.color_jitter_shift /= 255.0 + + @staticmethod + def color_jitter(img, shift): + """jitter color: randomly jitter the RGB values, in numpy formats""" + jitter_val = np.random.uniform(-shift, shift, 3).astype(np.float32) + img = img + jitter_val + img = np.clip(img, 0, 1) + return img + + @staticmethod + def color_jitter_pt(img, brightness, contrast, saturation, hue): + """jitter color: randomly jitter the brightness, contrast, saturation, and hue, in torch Tensor formats""" + fn_idx = paddle.randperm(4) + img = paddle.to_tensor(img, dtype=img.dtype) + for fn_id in fn_idx: + # print('fn_id',fn_id) + if fn_id == 0 and brightness is not None: + brightness_factor = paddle.to_tensor(1.0).uniform_( + brightness[0], brightness[1]).item() + # print("brightness_factor",brightness_factor) + img = adjust_brightness(img, brightness_factor) + if fn_id == 1 and contrast is not None: + contrast_factor = paddle.to_tensor(1.0).uniform_( + contrast[0], contrast[1]).item() + img = adjust_contrast(img, contrast_factor) + if fn_id == 2 and saturation is not None: + saturation_factor = paddle.to_tensor(1.0).uniform_( + saturation[0], saturation[1]).item() + img = adjust_saturation(img, saturation_factor) + if fn_id == 3 and hue is not None: + hue_factor = paddle.to_tensor(1.0).uniform_(hue[0], + hue[1]).item() + img = adjust_hue(img, hue_factor) + return img + + def get_component_coordinates(self, index, status): + """Get facial component (left_eye, right_eye, mouth) coordinates from a pre-loaded pth file""" + # print(f'{index:08d}',type(self.components_list)) + components_bbox = self.components_list[f'{index:08d}'] + if status[0]: + tmp = components_bbox['left_eye'] + components_bbox['left_eye'] = components_bbox['right_eye'] + components_bbox['right_eye'] = tmp + components_bbox['left_eye'][ + 0] = self.out_size - components_bbox['left_eye'][0] + components_bbox['right_eye'][ + 0] = self.out_size - components_bbox['right_eye'][0] + components_bbox['mouth'][ + 0] = self.out_size - components_bbox['mouth'][0] + locations = [] + for part in ['left_eye', 'right_eye', 'mouth']: + mean = components_bbox[part][0:2] + half_len = components_bbox[part][2] + if 'eye' in part: + half_len *= self.eye_enlarge_ratio + loc = np.hstack((mean - half_len + 1, mean + half_len)) + loc = paddle.to_tensor(loc) + locations.append(loc) + return locations + + def __getitem__(self, index): 
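+        # Generate the LQ image from the GT image on the fly:
+        #   1) resize the GT image to out_size and random horizontal flip
+        #   2) blur with a randomly sampled iso/aniso Gaussian kernel
+        #   3) random downsampling, then Gaussian noise and JPEG compression
+        #   4) resize back to the original resolution
+        #   5) optional color jitter / grayscale conversion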
+ if self.file_client is None: + self.file_client = FileClient(self.io_backend_opt.pop('type'), + **self.io_backend_opt) + gt_path = self.paths[index] + img_bytes = self.file_client.get(gt_path) + img_gt = imfrombytes(img_bytes, float32=True) + img_gt = cv2.resize(img_gt, (self.out_size, self.out_size)) + img_gt, status = augment(img_gt, + hflip=self.opt['use_hflip'], + rotation=False, + return_status=True) + h, w, _ = img_gt.shape + if self.crop_components: + locations = self.get_component_coordinates(index, status) + loc_left_eye, loc_right_eye, loc_mouth = locations + kernel = random_mixed_kernels(self.kernel_list, + self.kernel_prob, + self.blur_kernel_size, + self.blur_sigma, + self.blur_sigma, [-math.pi, math.pi], + noise_range=None) + img_lq = cv2.filter2D(img_gt, -1, kernel) + scale = np.random.uniform(self.downsample_range[0], + self.downsample_range[1]) + img_lq = cv2.resize(img_lq, (int(w // scale), int(h // scale)), + interpolation=cv2.INTER_LINEAR) + if self.noise_range is not None: + img_lq = random_add_gaussian_noise(img_lq, self.noise_range) + if self.jpeg_range is not None: + img_lq = random_add_jpg_compression(img_lq, self.jpeg_range) + img_lq = cv2.resize(img_lq, (w, h), interpolation=cv2.INTER_LINEAR) + if self.color_jitter_prob is not None and np.random.uniform( + ) < self.color_jitter_prob: + img_lq = self.color_jitter(img_lq, self.color_jitter_shift) + if self.gray_prob and np.random.uniform() < self.gray_prob: + img_lq = cv2.cvtColor(img_lq, cv2.COLOR_BGR2GRAY) + img_lq = np.tile(img_lq[:, :, None], [1, 1, 3]) + if self.opt.get('gt_gray'): + img_gt = cv2.cvtColor(img_gt, cv2.COLOR_BGR2GRAY) + img_gt = np.tile(img_gt[:, :, None], [1, 1, 3]) + img_gt, img_lq = img2tensor([img_gt, img_lq], + bgr2rgb=True, + float32=True) + if self.color_jitter_pt_prob is not None and np.random.uniform( + ) < self.color_jitter_pt_prob: + brightness = self.opt.get('brightness', (0.5, 1.5)) + contrast = self.opt.get('contrast', (0.5, 1.5)) + saturation = self.opt.get('saturation', (0, 1.5)) + hue = self.opt.get('hue', (-0.1, 0.1)) + img_lq = self.color_jitter_pt(img_lq, brightness, contrast, + saturation, hue) + img_lq = np.clip((img_lq * 255.0).round(), 0, 255) / 255.0 + img_gt = normalize(img_gt, self.mean, self.std) + img_lq = normalize(img_lq, self.mean, self.std) + if self.crop_components: + return_dict = { + 'lq': img_lq, + 'gt': img_gt, + 'gt_path': gt_path, + 'loc_left_eye': loc_left_eye, + 'loc_right_eye': loc_right_eye, + 'loc_mouth': loc_mouth + } + return return_dict + else: + return {'lq': img_lq, 'gt': img_gt, 'gt_path': gt_path} + + def __len__(self): + return len(self.paths) diff --git a/ppgan/datasets/paired_image_datasets.py b/ppgan/datasets/paired_image_datasets.py new file mode 100644 index 0000000..bdaae3c --- /dev/null +++ b/ppgan/datasets/paired_image_datasets.py @@ -0,0 +1,135 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import paddle +from paddle.vision.transforms.functional import normalize + +from .builder import DATASETS +from ppgan.utils.gfpgan_tools import * + + +@DATASETS.register() +class PairedImageDataset(paddle.io.Dataset): + """Paired image dataset for image restoration. + + Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and GT image pairs. + + There are three modes: + 1. 'lmdb': Use lmdb files. + If opt['io_backend'] == lmdb. + 2. 'meta_info_file': Use meta information file to generate paths. + If opt['io_backend'] != lmdb and opt['meta_info_file'] is not None. + 3. 'folder': Scan folders to generate paths. + The rest. + + Args: + opt (dict): Config for train datasets. It contains the following keys: + dataroot_gt (str): Data root path for gt. + dataroot_lq (str): Data root path for lq. + meta_info_file (str): Path for meta information file. + io_backend (dict): IO backend type and other kwarg. + filename_tmpl (str): Template for each filename. Note that the template excludes the file extension. + Default: '{}'. + gt_size (int): Cropped patched size for gt patches. + use_hflip (bool): Use horizontal flips. + use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation). + + scale (bool): Scale, which will be added automatically. + phase (str): 'train' or 'val'. + """ + def __init__(self, **opt): + super(PairedImageDataset, self).__init__() + self.opt = opt + # file client (io backend) + self.file_client = None + self.io_backend_opt = opt['io_backend'] + self.mean = opt['mean'] if 'mean' in opt else None + self.std = opt['std'] if 'std' in opt else None + + self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq'] + if 'filename_tmpl' in opt: + self.filename_tmpl = opt['filename_tmpl'] + else: + self.filename_tmpl = '{}' + + if self.io_backend_opt['type'] == 'lmdb': + self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder] + self.io_backend_opt['client_keys'] = ['lq', 'gt'] + self.paths = paired_paths_from_lmdb( + [self.lq_folder, self.gt_folder], ['lq', 'gt']) + elif 'meta_info_file' in self.opt and self.opt[ + 'meta_info_file'] is not None: + self.paths = paired_paths_from_meta_info_file( + [self.lq_folder, self.gt_folder], ['lq', 'gt'], + self.opt['meta_info_file'], self.filename_tmpl) + else: + self.paths = paired_paths_from_folder( + [self.lq_folder, self.gt_folder], ['lq', 'gt'], + self.filename_tmpl) + + def __getitem__(self, index): + if self.file_client is None: + self.file_client = FileClient(self.io_backend_opt.pop('type'), + **self.io_backend_opt) + # print(self.file_client) + scale = self.opt['scale'] + + # Load gt and lq images. Dimension order: HWC; channel order: BGR; + # image range: [0, 1], float32. 
+ gt_path = self.paths[index]['gt_path'] + img_bytes = self.file_client.get(gt_path, 'gt') + img_gt = imfrombytes(img_bytes, float32=True) + + lq_path = self.paths[index]['lq_path'] + img_bytes = self.file_client.get(lq_path, 'lq') + img_lq = imfrombytes(img_bytes, float32=True) + # augmentation for training + if self.opt['phase'] == 'train': + gt_size = self.opt['gt_size'] + # random crop + img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, + gt_path) + # flip, rotation + img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], + self.opt['use_rot']) + + # color space transform + if 'color' in self.opt and self.opt['color'] == 'y': + img_gt = bgr2ycbcr(img_gt, y_only=True)[..., None] + img_lq = bgr2ycbcr(img_lq, y_only=True)[..., None] + + # crop the unmatched GT images during validation or testing, especially for SR benchmark datasets + # TODO: It is better to update the datasets, rather than force to crop + if self.opt['phase'] != 'train': + img_gt = img_gt[0:img_lq.shape[0] * scale, + 0:img_lq.shape[1] * scale, :] + + # BGR to RGB, HWC to CHW, numpy to tensor + img_gt, img_lq = img2tensor([img_gt, img_lq], + bgr2rgb=True, + float32=True) + # normalize + + if self.mean is not None or self.std is not None: + img_lq = normalize(img_lq, self.mean, self.std) + img_gt = normalize(img_gt, self.mean, self.std) + + return { + 'lq': img_lq, + 'gt': img_gt, + 'lq_path': lq_path, + 'gt_path': gt_path + } + + def __len__(self): + return len(self.paths) diff --git a/ppgan/faceutils/face_enhancement/gfpgan_enhance.py b/ppgan/faceutils/face_enhancement/gfpgan_enhance.py new file mode 100644 index 0000000..707e5c1 --- /dev/null +++ b/ppgan/faceutils/face_enhancement/gfpgan_enhance.py @@ -0,0 +1,87 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import math +import cv2 +import numpy as np +import sys + +import paddle +import paddle.nn as nn + +from ppgan.utils.visual import * +from ppgan.utils.download import get_path_from_url +from ppgan.models.generators import GFPGANv1Clean +from ppgan.models.generators import GFPGANv1 +from ppgan.faceutils.face_detection.detection.blazeface.utils import * +GFPGAN_weights = 'https://paddlegan.bj.bcebos.com/models/GFPGAN.pdparams' + + +class gfp_FaceEnhancement(object): + def __init__(self, size=512, batch_size=1): + super(gfp_FaceEnhancement, self).__init__() + + # Initialise the face detector + model_weights_path = get_path_from_url(GFPGAN_weights) + model_weights = paddle.load(model_weights_path) + + self.face_enhance = GFPGANv1(out_size=512, + num_style_feat=512, + channel_multiplier=1, + resample_kernel=[1, 3, 3, 1], + decoder_load_path=None, + fix_decoder=True, + num_mlp=8, + lr_mlp=0.01, + input_is_latent=True, + different_w=True, + narrow=1, + sft_half=True) + self.face_enhance.load_dict(model_weights['net_g_ema']) + self.face_enhance.eval() + self.size = size + self.mask = np.zeros((512, 512), np.float32) + cv2.rectangle(self.mask, (26, 26), (486, 486), (1, 1, 1), -1, + cv2.LINE_AA) + self.mask = cv2.GaussianBlur(self.mask, (101, 101), 11) + self.mask = cv2.GaussianBlur(self.mask, (101, 101), 11) + self.mask = paddle.tile(paddle.to_tensor( + self.mask).unsqueeze(0).unsqueeze(-1), + repeat_times=[batch_size, 1, 1, 3]).numpy() + + def enhance_from_image(self, img): + if isinstance(img, np.ndarray): + img, _ = resize_and_crop_image(img, 512) + img = paddle.to_tensor(img).transpose([2, 0, 1]) + + else: + assert img.shape == [3, 512, 512] + return self.enhance_from_batch(img.unsqueeze(0))[0] + + def enhance_from_batch(self, img): + if isinstance(img, np.ndarray): + img_ori, _ = resize_and_crop_batch(img, 512) + img = paddle.to_tensor(img_ori).transpose([0, 3, 1, 2]) + else: + assert img.shape[1:] == [3, 512, 512] + img_ori = img.transpose([0, 2, 3, 1]).numpy() + img_t = (img / 255. - 0.5) / 0.5 + + with paddle.no_grad(): + out, __ = self.face_enhance(img_t) + image_tensor = out * 0.5 + 0.5 + image_tensor = image_tensor.transpose([0, 2, 3, 1]) # RGB + image_numpy = paddle.clip(image_tensor, 0, 1) * 255.0 + + out = image_numpy.astype(np.uint8).cpu().numpy() + return out * self.mask + (1 - self.mask) * img_ori diff --git a/ppgan/models/__init__.py b/ppgan/models/__init__.py index 887adeb..24990dc 100644 --- a/ppgan/models/__init__.py +++ b/ppgan/models/__init__.py @@ -39,4 +39,5 @@ from .rcan_model import RCANModel from .prenet_model import PReNetModel from .gpen_model import GPENModel from .swinir_model import SwinIRModel +from .gfpgan_model import GFPGANModel from .invdn_model import InvDNModel diff --git a/ppgan/models/criterions/__init__.py b/ppgan/models/criterions/__init__.py index a4c6e81..fdd609c 100644 --- a/ppgan/models/criterions/__init__.py +++ b/ppgan/models/criterions/__init__.py @@ -10,3 +10,4 @@ from .builder import build_criterion from .ssim import SSIM from .id_loss import IDLoss +from .gfpgan_loss import GFPGANGANLoss, GFPGANL1Loss, GFPGANPerceptualLoss diff --git a/ppgan/models/criterions/gfpgan_loss.py b/ppgan/models/criterions/gfpgan_loss.py new file mode 100644 index 0000000..1e66a93 --- /dev/null +++ b/ppgan/models/criterions/gfpgan_loss.py @@ -0,0 +1,427 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import cv2 +import math +import numpy as np +from collections import OrderedDict +import os + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +from paddle.vision.models import vgg + +from .builder import CRITERIONS +from ppgan.utils.download import get_path_from_url + +VGG_PRETRAIN_PATH = os.path.join(os.getcwd(), 'pretrain', 'vgg19' + '.pdparams') +NAMES = { + 'vgg11': [ + 'conv1_1', 'relu1_1', 'pool1', 'conv2_1', 'relu2_1', 'pool2', 'conv3_1', + 'relu3_1', 'conv3_2', 'relu3_2', 'pool3', 'conv4_1', 'relu4_1', + 'conv4_2', 'relu4_2', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', + 'relu5_2', 'pool5' + ], + 'vgg13': [ + 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', + 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', 'conv3_1', 'relu3_1', + 'conv3_2', 'relu3_2', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', + 'relu4_2', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'pool5' + ], + 'vgg16': [ + 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', + 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', 'conv3_1', 'relu3_1', + 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'pool3', 'conv4_1', + 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'pool4', + 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', + 'pool5' + ], + 'vgg19': [ + 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', + 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', 'conv3_1', 'relu3_1', + 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', + 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', + 'relu4_3', 'conv4_4', 'relu4_4', 'pool4', 'conv5_1', 'relu5_1', + 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', + 'pool5' + ] +} + + +def insert_bn(names): + """Insert bn layer after each conv. + + Args: + names (list): The list of layer names. + + Returns: + list: The list of layer names with bn layers. + """ + names_bn = [] + for name in names: + names_bn.append(name) + if 'conv' in name: + position = name.replace('conv', '') + names_bn.append('bn' + position) + return names_bn + + +class VGGFeatureExtractor(nn.Layer): + """VGG network for feature extraction. + + In this implementation, we allow users to choose whether use normalization + in the input feature and the type of vgg network. Note that the pretrained + path must fit the vgg type. + + Args: + layer_name_list (list[str]): Forward function returns the corresponding + features according to the layer_name_list. + Example: {'relu1_1', 'relu2_1', 'relu3_1'}. + vgg_type (str): Set the type of vgg network. Default: 'vgg19'. + use_input_norm (bool): If True, normalize the input image. Importantly, + the input feature must in the range [0, 1]. Default: True. + range_norm (bool): If True, norm images with range [-1, 1] to [0, 1]. + Default: False. + requires_grad (bool): If true, the parameters of VGG network will be + optimized. Default: False. + remove_pooling (bool): If true, the max pooling operations in VGG net + will be removed. Default: False. + pooling_stride (int): The stride of max pooling operation. Default: 2. 
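+        pretrained_url (str): URL used by get_path_from_url to fetch the VGG
+            weights when a local checkpoint exists at VGG_PRETRAIN_PATH;
+            otherwise the paddle.vision pretrained weights are used.
+            Default: the PaddleGAN-hosted vgg19 weights.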
+ """ + def __init__( + self, + layer_name_list, + vgg_type='vgg19', + use_input_norm=True, + range_norm=False, + requires_grad=False, + remove_pooling=False, + pooling_stride=2, + pretrained_url='https://paddlegan.bj.bcebos.com/models/vgg19.pdparams' + ): + super(VGGFeatureExtractor, self).__init__() + self.layer_name_list = layer_name_list + self.use_input_norm = use_input_norm + self.range_norm = range_norm + self.names = NAMES[vgg_type.replace('_bn', '')] + if 'bn' in vgg_type: + self.names = insert_bn(self.names) + max_idx = 0 + for v in layer_name_list: + idx = self.names.index(v) + if idx > max_idx: + max_idx = idx + if os.path.exists(VGG_PRETRAIN_PATH): + vgg_net = getattr(vgg, vgg_type)(pretrained=False) + weight_path = get_path_from_url(pretrained_url) + state_dict = paddle.load(weight_path) + vgg_net.set_state_dict(state_dict) + else: + vgg_net = getattr(vgg, vgg_type)(pretrained=True) + features = vgg_net.features[:max_idx + 1] + self.vgg_layers = nn.Sequential() + for k, v in zip(self.names, features): + if 'pool' in k: + if remove_pooling: + continue + else: + self.vgg_layers.add_sublayer( + k, nn.MaxPool2D(kernel_size=2, stride=pooling_stride)) + else: + self.vgg_layers.add_sublayer(k, v) + + if not requires_grad: + self.vgg_layers.eval() + for param in self.parameters(): + param.stop_gradient = True + else: + self.vgg_layers.train() + for param in self.parameters(): + param.stop_gradient = False + if self.use_input_norm: + self.register_buffer( + 'mean', + paddle.to_tensor([0.485, 0.456, 0.406]).reshape([1, 3, 1, 1])) + self.register_buffer( + 'std', + paddle.to_tensor([0.229, 0.224, 0.225]).reshape([1, 3, 1, 1])) + + def forward(self, x, rep=None): + """Forward function. + + Args: + x (Tensor): Input tensor with shape (n, c, h, w). + + Returns: + Tensor: Forward results. + """ + if self.range_norm: + x = (x + 1) / 2 + if self.use_input_norm: + x = (x - self.mean) / self.std + output = {} + + for name, module in self.vgg_layers.named_children(): + x = module(x) + if name in self.layer_name_list: + output[name] = x.clone() + return output + + +@CRITERIONS.register() +class GFPGANPerceptualLoss(nn.Layer): + """Perceptual loss with commonly used style loss. + + Args: + layer_weights (dict): The weight for each layer of vgg feature. + Here is an example: {'conv5_4': 1.}, which means the conv5_4 + feature layer (before relu5_4) will be extracted with weight + 1.0 in calculating losses. + vgg_type (str): The type of vgg network used as feature extractor. + Default: 'vgg19'. + use_input_norm (bool): If True, normalize the input image in vgg. + Default: True. + range_norm (bool): If True, norm images with range [-1, 1] to [0, 1]. + Default: False. + perceptual_weight (float): If `perceptual_weight > 0`, the perceptual + loss will be calculated and the loss will multiplied by the + weight. Default: 1.0. + style_weight (float): If `style_weight > 0`, the style loss will be + calculated and the loss will multiplied by the weight. + Default: 0. + criterion (str): Criterion used for perceptual loss. Default: 'l1'. 
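+            Supported choices are 'l1' and 'fro' (Frobenius norm).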
+ """ + def __init__(self, + layer_weights, + vgg_type='vgg19', + use_input_norm=True, + range_norm=False, + perceptual_weight=1.0, + style_weight=0.0, + criterion='l1'): + super(GFPGANPerceptualLoss, self).__init__() + self.perceptual_weight = perceptual_weight + self.style_weight = style_weight + self.layer_weights = layer_weights + self.vgg = VGGFeatureExtractor(layer_name_list=list( + layer_weights.keys()), + vgg_type=vgg_type, + use_input_norm=use_input_norm, + range_norm=range_norm) + self.criterion_type = criterion + if self.criterion_type == 'l1': + self.criterion = paddle.nn.L1Loss() + elif self.criterion_type == 'fro': + self.criterion = None + else: + raise NotImplementedError( + f'{criterion} criterion has not been supported.') + + def forward(self, x, gt, rep=None): + """Forward function. + + Args: + x (Tensor): Input tensor with shape (n, c, h, w). + gt (Tensor): Ground-truth tensor with shape (n, c, h, w). + + Returns: + Tensor: Forward results. + """ + x_features = self.vgg(x, rep) + gt_features = self.vgg(gt.detach()) + if self.perceptual_weight > 0: + percep_loss = 0 + for k in x_features.keys(): + if self.criterion_type == 'fro': + percep_loss += paddle.linalg.norm( + x_features[k] - gt_features[k], + p='fro') * self.layer_weights[k] + else: + percep_loss += self.criterion( + x_features[k], gt_features[k]) * self.layer_weights[k] + percep_loss *= self.perceptual_weight + else: + percep_loss = None + if self.style_weight > 0: + style_loss = 0 + for k in x_features.keys(): + if self.criterion_type == 'fro': + style_loss += paddle.linalg.norm( + self._gram_mat(x_features[k]) - + self._gram_mat(gt_features[k]), + p='fro') * self.layer_weights[k] + else: + style_loss += self.criterion( + self._gram_mat(x_features[k]), + self._gram_mat(gt_features[k])) * self.layer_weights[k] + style_loss *= self.style_weight + else: + style_loss = None + return percep_loss, style_loss + + def _gram_mat(self, x): + """Calculate Gram matrix. + + Args: + x (torch.Tensor): Tensor with shape of (n, c, h, w). + + Returns: + torch.Tensor: Gram matrix. + """ + (n, c, h, w) = x.shape + features = x.reshape([n, c, w * h]) + features_t = features.transpose([0, 2, 1]) + gram = features.bmm(features_t) / (c * h * w) + return gram + + +@CRITERIONS.register() +class GFPGANGANLoss(nn.Layer): + """Define GAN loss. + + Args: + gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'hinge'. + real_label_val (float): The value for real label. Default: 1.0. + fake_label_val (float): The value for fake label. Default: 0.0. + loss_weight (float): Loss weight. Default: 1.0. + Note that loss_weight is only for generators; and it is always 1.0 + for discriminators. + """ + def __init__(self, + gan_type, + real_label_val=1.0, + fake_label_val=0.0, + loss_weight=1.0): + super(GFPGANGANLoss, self).__init__() + self.gan_type = gan_type + self.loss_weight = loss_weight + self.real_label_val = real_label_val + self.fake_label_val = fake_label_val + + if self.gan_type == 'vanilla': + self.loss = nn.BCEWithLogitsLoss() + elif self.gan_type == 'lsgan': + self.loss = nn.MSELoss() + elif self.gan_type == 'wgan': + self.loss = self._wgan_loss + elif self.gan_type == 'wgan_softplus': + self.loss = self._wgan_softplus_loss + elif self.gan_type == 'hinge': + self.loss = nn.ReLU() + else: + raise NotImplementedError( + f'GAN type {self.gan_type} is not implemented.') + + def _wgan_loss(self, input, target): + """wgan loss. + + Args: + input (Tensor): Input tensor. + target (bool): Target label. + + Returns: + Tensor: wgan loss. 
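+                Computed as -mean(input) for real targets and mean(input)
+                for fake targets.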
+ """ + return -input.mean() if target else input.mean() + + def _wgan_softplus_loss(self, input, target): + """wgan loss with soft plus. softplus is a smooth approximation to the + ReLU function. + + In StyleGAN2, it is called: + Logistic loss for discriminator; + Non-saturating loss for generator. + + Args: + input (Tensor): Input tensor. + target (bool): Target label. + + Returns: + Tensor: wgan loss. + """ + + return F.softplus(-1.0 * + input).mean() if target else F.softplus(input).mean() + + def get_target_label(self, input, target_is_real): + """Get target label. + + Args: + input (Tensor): Input tensor. + target_is_real (bool): Whether the target is real or fake. + + Returns: + (bool | Tensor): Target tensor. Return bool for wgan, otherwise, + return Tensor. + """ + + if self.gan_type in ['wgan', 'wgan_softplus']: + return target_is_real + target_val = (self.real_label_val + if target_is_real else self.fake_label_val) + return paddle.ones(input.shape, dtype=input.dtype) * target_val + + def forward(self, input, target_is_real, is_disc=False): + """ + Args: + input (Tensor): The input for the loss module, i.e., the network + prediction. + target_is_real (bool): Whether the targe is real or fake. + is_disc (bool): Whether the loss for discriminators or not. + Default: False. + + Returns: + Tensor: GAN loss value. + """ + target_label = self.get_target_label(input, target_is_real) + if self.gan_type == 'hinge': + if is_disc: # for discriminators in hinge-gan + input = -input if target_is_real else input + loss = self.loss(1 + input).mean() + else: # for generators in hinge-gan + loss = -input.mean() + else: # other gan types + loss = self.loss(input, target_label) + + # loss_weight is always 1.0 for discriminators + return loss if is_disc else loss * self.loss_weight + + +@CRITERIONS.register() +class GFPGANL1Loss(nn.Layer): + """L1 (mean absolute error, MAE) loss. + + Args: + loss_weight (float): Loss weight for L1 loss. Default: 1.0. + reduction (str): Specifies the reduction to apply to the output. + Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'. + """ + def __init__(self, loss_weight=1.0, reduction='mean'): + super(GFPGANL1Loss, self).__init__() + if reduction not in ['none', 'mean', 'sum']: + raise ValueError( + f'Unsupported reduction mode: {reduction}. Supported ones are: "none" | "mean" | "sum"' + ) + + self.loss_weight = loss_weight + self.l1_loss = paddle.nn.L1Loss(reduction) + + def forward(self, pred, target): + """ + Args: + pred (Tensor): of shape (N, C, H, W). Predicted tensor. + target (Tensor): of shape (N, C, H, W). Ground truth tensor. + weight (Tensor, optional): of shape (N, C, H, W). Element-wise weights. Default: None. 
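+
+        Returns:
+            Tensor: loss_weight * L1(pred, target).
+        Note: the per-element ``weight`` argument is not implemented in this
+            Paddle port; only ``pred`` and ``target`` are accepted.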
+ """ + return self.loss_weight * self.l1_loss(pred, target) diff --git a/ppgan/models/discriminators/__init__.py b/ppgan/models/discriminators/__init__.py index bacedac..9d64778 100644 --- a/ppgan/models/discriminators/__init__.py +++ b/ppgan/models/discriminators/__init__.py @@ -25,3 +25,4 @@ from .discriminator_firstorder import FirstOrderDiscriminator from .discriminator_lapstyle import LapStyleDiscriminator from .discriminator_photopen import MultiscaleDiscriminator from .discriminator_singan import SinGANDiscriminator +from .arcface_arch_paddle import ResNetArcFace diff --git a/ppgan/models/discriminators/arcface_arch_paddle.py b/ppgan/models/discriminators/arcface_arch_paddle.py new file mode 100644 index 0000000..5a84653 --- /dev/null +++ b/ppgan/models/discriminators/arcface_arch_paddle.py @@ -0,0 +1,285 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + +from .builder import DISCRIMINATORS + + +def conv3x3(inplanes, outplanes, stride=1): + """A simple wrapper for 3x3 convolution with padding. + + Args: + inplanes (int): Channel number of inputs. + outplanes (int): Channel number of outputs. + stride (int): Stride in convolution. Default: 1. + """ + return nn.Conv2D(inplanes, + outplanes, + kernel_size=3, + stride=stride, + padding=1, + bias_attr=False) + + +class BasicBlock(nn.Layer): + """Basic residual block used in the ResNetArcFace architecture. + + Args: + inplanes (int): Channel number of inputs. + planes (int): Channel number of outputs. + stride (int): Stride in convolution. Default: 1. + downsample (nn.Module): The downsample module. Default: None. + """ + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = nn.BatchNorm2D(planes) + self.relu = nn.ReLU() + self.conv2 = conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2D(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + out = self.conv2(out) + out = self.bn2(out) + if self.downsample is not None: + residual = self.downsample(x) + out += residual + out = self.relu(out) + return out + + +class IRBlock(nn.Layer): + """Improved residual block (IR Block) used in the ResNetArcFace architecture. + + Args: + inplanes (int): Channel number of inputs. + planes (int): Channel number of outputs. + stride (int): Stride in convolution. Default: 1. + downsample (nn.Module): The downsample module. Default: None. + use_se (bool): Whether use the SEBlock (squeeze and excitation block). Default: True. 
+ """ + expansion = 1 + + def __init__(self, + inplanes, + planes, + stride=1, + downsample=None, + use_se=True): + super(IRBlock, self).__init__() + self.bn0 = nn.BatchNorm2D(inplanes) + self.conv1 = conv3x3(inplanes, inplanes) + self.bn1 = nn.BatchNorm2D(inplanes) + self.prelu = PReLU_layer() + self.conv2 = conv3x3(inplanes, planes, stride) + self.bn2 = nn.BatchNorm2D(planes) + self.downsample = downsample + self.stride = stride + self.use_se = use_se + if self.use_se: + self.se = SEBlock(planes) + + def forward(self, x): + residual = x + out = self.bn0(x) + out = self.conv1(out) + out = self.bn1(out) + out = self.prelu(out) + out = self.conv2(out) + out = self.bn2(out) + if self.use_se: + out = self.se(out) + if self.downsample is not None: + residual = self.downsample(x) + out += residual + out = self.prelu(out) + return out + + +class Bottleneck(nn.Layer): + """Bottleneck block used in the ResNetArcFace architecture. + + Args: + inplanes (int): Channel number of inputs. + planes (int): Channel number of outputs. + stride (int): Stride in convolution. Default: 1. + downsample (nn.Module): The downsample module. Default: None. + """ + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2D(inplanes, planes, kernel_size=1, bias_attr=False) + self.bn1 = nn.BatchNorm2D(planes) + self.conv2 = nn.Conv2D(planes, + planes, + kernel_size=3, + stride=stride, + padding=1, + bias_attr=False) + self.bn2 = nn.BatchNorm2D(planes) + self.conv3 = nn.Conv2D(planes, + planes * self.expansion, + kernel_size=1, + bias_attr=False) + self.bn3 = nn.BatchNorm2D(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + out = self.conv3(out) + out = self.bn3(out) + if self.downsample is not None: + residual = self.downsample(x) + out += residual + out = self.relu(out) + return out + + +class PReLU_layer(nn.Layer): + def __init__(self, init_value=0.25, num=1): + super(PReLU_layer, self).__init__() + x = self.create_parameter( + attr=None, + shape=[num], + dtype=paddle.get_default_dtype(), + is_bias=False, + default_initializer=nn.initializer.Constant(init_value)) + self.add_parameter('weight', x) + + def forward(self, x): + return F.prelu(x, self.weight) + + +class SEBlock(nn.Layer): + """The squeeze-and-excitation block (SEBlock) used in the IRBlock. + + Args: + channel (int): Channel number of inputs. + reduction (int): Channel reduction ration. Default: 16. + """ + def __init__(self, channel, reduction=16): + super(SEBlock, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2D(1) + self.fc = nn.Sequential(nn.Linear(channel, channel // reduction), + nn.PReLU(), + nn.Linear(channel // reduction, channel), + nn.Sigmoid()) + + def forward(self, x): + b, c, _, _ = x.size() + y = self.avg_pool(x).view(b, c) + y = self.fc(y).view(b, c, 1, 1) + return x * y + + +def constant_init(param, **kwargs): + initializer = nn.initializer.Constant(**kwargs) + initializer(param, param.block) + + +@DISCRIMINATORS.register() +class ResNetArcFace(nn.Layer): + """ArcFace with ResNet architectures. + + Ref: ArcFace: Additive Angular Margin Loss for Deep Face Recognition. + + Args: + block (str): Block used in the ArcFace architecture. + layers (tuple(int)): Block numbers in each layer. 
+ use_se (bool): Whether use the SEBlock (squeeze and excitation block). Default: True. + """ + def __init__(self, block, layers, use_se=True, reprod_logger=None): + if block == 'IRBlock': + block = IRBlock + self.inplanes = 64 + self.use_se = use_se + super(ResNetArcFace, self).__init__() + self.conv1 = nn.Conv2D(1, 64, kernel_size=3, padding=1, bias_attr=False) + self.bn1 = nn.BatchNorm2D(64) + self.maxpool = nn.MaxPool2D(kernel_size=2, stride=2) + self.prelu = PReLU_layer() + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + self.bn4 = nn.BatchNorm2D(512) + self.dropout = nn.Dropout() + self.fc5 = nn.Linear(512 * 8 * 8, 512) + self.bn5 = nn.BatchNorm1D(512) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, paddle.nn.Conv2D): + nn.initializer.XavierNormal(m.weight) + elif isinstance(m, paddle.nn.BatchNorm2D) or isinstance( + m, paddle.nn.BatchNorm1D): + constant_init(m.weight, value=1.) + constant_init(m.bias, value=0.) + elif isinstance(m, paddle.nn.Linear): + nn.initializer.XavierNormal(m.weight) + constant_init(m.bias, value=0.) + + def _make_layer(self, block, planes, num_blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2D(self.inplanes, + planes * block.expansion, + kernel_size=1, + stride=stride, + bias_attr=False), + nn.BatchNorm2D(planes * block.expansion)) + layers = [] + layers.append( + block(self.inplanes, planes, stride, downsample, + use_se=self.use_se)) + self.inplanes = planes + for _ in range(1, num_blocks): + layers.append(block(self.inplanes, planes, use_se=self.use_se)) + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.prelu(x) + x = self.maxpool(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.bn4(x) + x = self.dropout(x) + x = x.reshape([x.shape[0], -1]) + x = self.fc5(x) + x = self.bn5(x) + return x diff --git a/ppgan/models/generators/__init__.py b/ppgan/models/generators/__init__.py old mode 100755 new mode 100644 index 5a14f7b..71834e2 --- a/ppgan/models/generators/__init__.py +++ b/ppgan/models/generators/__init__.py @@ -43,4 +43,6 @@ from .rcan import RCAN from .prenet import PReNet from .gpen import GPEN from .swinir import SwinIR +from .gfpganv1_clean_arch import GFPGANv1Clean +from .gfpganv1_arch import GFPGANv1, StyleGAN2DiscriminatorGFPGAN from .invdn import InvDN diff --git a/ppgan/models/generators/gfpganv1_arch.py b/ppgan/models/generators/gfpganv1_arch.py new file mode 100644 index 0000000..b9c3345 --- /dev/null +++ b/ppgan/models/generators/gfpganv1_arch.py @@ -0,0 +1,1418 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import math +import random +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +import numpy as np + +from ppgan.models.discriminators.builder import DISCRIMINATORS +from ppgan.models.generators.builder import GENERATORS +from ppgan.utils.download import get_path_from_url + + +class StyleGAN2Generator(nn.Layer): + """StyleGAN2 Generator. + + Args: + out_size (int): The spatial size of outputs. + num_style_feat (int): Channel number of style features. Default: 512. + num_mlp (int): Layer number of MLP style layers. Default: 8. + channel_multiplier (int): Channel multiplier for large networks of + StyleGAN2. Default: 2. + resample_kernel (list[int]): A list indicating the 1D resample kernel + magnitude. A cross production will be applied to extent 1D resample + kenrel to 2D resample kernel. Default: (1, 3, 3, 1). + lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01. + narrow (float): Narrow ratio for channels. Default: 1.0. + """ + def __init__(self, + out_size, + num_style_feat=512, + num_mlp=8, + channel_multiplier=2, + resample_kernel=(1, 3, 3, 1), + lr_mlp=0.01, + narrow=1): + super(StyleGAN2Generator, self).__init__() + self.num_style_feat = num_style_feat + style_mlp_layers = [NormStyleCode()] + for i in range(num_mlp): + style_mlp_layers.append( + EqualLinear(num_style_feat, + num_style_feat, + bias=True, + bias_init_val=0, + lr_mul=lr_mlp, + activation='fused_lrelu')) + self.style_mlp = nn.Sequential(*style_mlp_layers) + channels = { + '4': int(512 * narrow), + '8': int(512 * narrow), + '16': int(512 * narrow), + '32': int(512 * narrow), + '64': int(256 * channel_multiplier * narrow), + '128': int(128 * channel_multiplier * narrow), + '256': int(64 * channel_multiplier * narrow), + '512': int(32 * channel_multiplier * narrow), + '1024': int(16 * channel_multiplier * narrow) + } + self.channels = channels + self.constant_input = ConstantInput(channels['4'], size=4) + self.style_conv1 = StyleConv(channels['4'], + channels['4'], + kernel_size=3, + num_style_feat=num_style_feat, + demodulate=True, + sample_mode=None, + resample_kernel=resample_kernel) + self.to_rgb1 = ToRGB(channels['4'], + num_style_feat, + upsample=False, + resample_kernel=resample_kernel) + self.log_size = int(math.log(out_size, 2)) + self.num_layers = (self.log_size - 2) * 2 + 1 + self.num_latent = self.log_size * 2 - 2 + self.style_convs = nn.LayerList() + self.to_rgbs = nn.LayerList() + self.noises = nn.Layer() + in_channels = channels['4'] + for layer_idx in range(self.num_layers): + resolution = 2**((layer_idx + 5) // 2) + shape = [1, 1, resolution, resolution] + x = paddle.ones(shape=shape, dtype='float32') + self.noises.register_buffer(f'noise{layer_idx}', x) + for i in range(3, self.log_size + 1): + out_channels = channels[f'{2 ** i}'] + self.style_convs.append( + StyleConv(in_channels, + out_channels, + kernel_size=3, + num_style_feat=num_style_feat, + demodulate=True, + sample_mode='upsample', + resample_kernel=resample_kernel)) + self.style_convs.append( + StyleConv(out_channels, + out_channels, + kernel_size=3, + num_style_feat=num_style_feat, + demodulate=True, + sample_mode=None, + resample_kernel=resample_kernel)) + self.to_rgbs.append( + ToRGB(out_channels, + num_style_feat, + upsample=True, + resample_kernel=resample_kernel)) + in_channels = out_channels + + def make_noise(self): + """Make noise for noise injection.""" + device = self.constant_input.weight.device + x = paddle.ones(shape=[1, 1, 4, 4], dtype='float32') + noises = [x] + for i in range(3, 
self.log_size + 1):
+            for _ in range(2):
+                x = paddle.ones(shape=[1, 1, 2**i, 2**i], dtype='float32')
+                noises.append(x)
+        return noises
+
+    def get_latent(self, x):
+        return self.style_mlp(x)
+
+    def mean_latent(self, num_latent):
+        x = paddle.ones(shape=[num_latent, self.num_style_feat],
+                        dtype='float32')
+        latent_in = x
+        latent = self.style_mlp(latent_in).mean(0, keepdim=True)
+        return latent
+
+    def forward(self,
+                styles,
+                input_is_latent=False,
+                noise=None,
+                randomize_noise=True,
+                truncation=1,
+                truncation_latent=None,
+                inject_index=None,
+                return_latents=False):
+        """Forward function for StyleGAN2Generator.
+
+        Args:
+            styles (list[Tensor]): Sample codes of styles.
+            input_is_latent (bool): Whether input is latent style.
+                Default: False.
+            noise (Tensor | None): Input noise or None. Default: None.
+            randomize_noise (bool): Randomize noise, used when 'noise' is
+                False. Default: True.
+            truncation (float): The truncation ratio. Default: 1.
+            truncation_latent (Tensor | None): The truncation latent tensor.
+                Default: None.
+            inject_index (int | None): The injection index for mixing noise.
+                Default: None.
+            return_latents (bool): Whether to return style latents.
+                Default: False.
+        """
+        if not input_is_latent:
+            styles = [self.style_mlp(s) for s in styles]
+        if noise is None:
+            if randomize_noise:
+                noise = [None] * self.num_layers
+            else:
+                noise = [
+                    getattr(self.noises, f'noise{i}')
+                    for i in range(self.num_layers)
+                ]
+        if truncation < 1:
+            style_truncation = []
+            for style in styles:
+                style_truncation.append(truncation_latent + truncation *
+                                        (style - truncation_latent))
+            styles = style_truncation
+        if len(styles) == 1:
+            inject_index = self.num_latent
+            if styles[0].ndim < 3:
+                latent = styles[0].unsqueeze(1)
+                latent = paddle.tile(latent, repeat_times=[1, inject_index, 1])
+            else:
+                latent = styles[0]
+        elif len(styles) == 2:
+            if inject_index is None:
+                inject_index = random.randint(1, self.num_latent - 1)
+            latent1 = styles[0].unsqueeze(1)
+            latent1 = paddle.tile(latent1, repeat_times=[1, inject_index, 1])
+
+            latent2 = styles[1].unsqueeze(1)
+            latent2 = paddle.tile(
+                latent2, repeat_times=[1, self.num_latent - inject_index, 1])
+            latent = paddle.concat([latent1, latent2], 1)
+        out = self.constant_input(latent.shape[0])
+        out = self.style_conv1(out, latent[:, 0], noise=noise[0])
+        skip = self.to_rgb1(out, latent[:, 1])
+        i = 1
+        for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2],
+                                                        self.style_convs[1::2],
+                                                        noise[1::2],
+                                                        noise[2::2],
+                                                        self.to_rgbs):
+            out = conv1(out, latent[:, i], noise=noise1)
+            out = conv2(out, latent[:, i + 1], noise=noise2)
+            skip = to_rgb(out, latent[:, i + 2], skip)
+            i += 2
+        image = skip
+        if return_latents:
+            return image, latent
+        else:
+            return image, None
+
+
+def var(x, axis=None, unbiased=True, keepdim=False, name=None):
+
+    u = paddle.mean(x, axis, True, name)
+    out = paddle.sum((x - u) * (x - u), axis, keepdim=keepdim, name=name)
+
+    n = paddle.cast(paddle.numel(x), x.dtype) \
+        / paddle.cast(paddle.numel(out), x.dtype)
+    if unbiased:
+        one_const = paddle.ones([1], x.dtype)
+        n = paddle.where(n > one_const, n - 1., one_const)
+    out /= n
+    return out
+
+
+@DISCRIMINATORS.register()
+class StyleGAN2DiscriminatorGFPGAN(nn.Layer):
+    """StyleGAN2 Discriminator.
+
+    Args:
+        out_size (int): The spatial size of outputs.
+        channel_multiplier (int): Channel multiplier for large networks of
+            StyleGAN2. Default: 2.
+        resample_kernel (list[int]): A list indicating the 1D resample kernel
+            magnitude.
A cross production will be applied to extent 1D resample + kenrel to 2D resample kernel. Default: (1, 3, 3, 1). + stddev_group (int): For group stddev statistics. Default: 4. + narrow (float): Narrow ratio for channels. Default: 1.0. + """ + def __init__(self, + out_size, + channel_multiplier=2, + resample_kernel=(1, 3, 3, 1), + stddev_group=4, + narrow=1): + super(StyleGAN2DiscriminatorGFPGAN, self).__init__() + channels = { + '4': int(512 * narrow), + '8': int(512 * narrow), + '16': int(512 * narrow), + '32': int(512 * narrow), + '64': int(256 * channel_multiplier * narrow), + '128': int(128 * channel_multiplier * narrow), + '256': int(64 * channel_multiplier * narrow), + '512': int(32 * channel_multiplier * narrow), + '1024': int(16 * channel_multiplier * narrow) + } + log_size = int(math.log(out_size, 2)) + conv_body = [ + ConvLayer(3, channels[f'{out_size}'], 1, bias=True, activate=True) + ] + in_channels = channels[f'{out_size}'] + for i in range(log_size, 2, -1): + out_channels = channels[f'{2 ** (i - 1)}'] + conv_body.append( + ResBlock(in_channels, out_channels, resample_kernel)) + in_channels = out_channels + self.conv_body = nn.Sequential(*conv_body) + self.final_conv = ConvLayer(in_channels + 1, + channels['4'], + 3, + bias=True, + activate=True) + self.final_linear = nn.Sequential( + EqualLinear(channels['4'] * 4 * 4, + channels['4'], + bias=True, + bias_init_val=0, + lr_mul=1, + activation='fused_lrelu'), + EqualLinear(channels['4'], + 1, + bias=True, + bias_init_val=0, + lr_mul=1, + activation=None)) + self.stddev_group = stddev_group + self.stddev_feat = 1 + + def forward(self, x): + out = self.conv_body(x) + b, c, h, w = out.shape + group = min(b, self.stddev_group) + stddev = out.reshape( + [group, -1, self.stddev_feat, c // self.stddev_feat, h, w]) + stddev = paddle.sqrt(var(stddev, 0, unbiased=False) + 1e-08) + stddev = stddev.mean(axis=[2, 3, 4], keepdim=True).squeeze(2) + + stddev = paddle.tile(stddev, repeat_times=[group, 1, h, w]) + out = paddle.concat([out, stddev], 1) + out = self.final_conv(out) + out = out.reshape([b, -1]) + out = self.final_linear(out) + return out + + +class StyleGAN2GeneratorSFT(StyleGAN2Generator): + """StyleGAN2 Generator with SFT modulation (Spatial Feature Transform). + + Args: + out_size (int): The spatial size of outputs. + num_style_feat (int): Channel number of style features. Default: 512. + num_mlp (int): Layer number of MLP style layers. Default: 8. + channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2. + resample_kernel (list[int]): A list indicating the 1D resample kernel magnitude. A cross production will be + applied to extent 1D resample kernel to 2D resample kernel. Default: (1, 3, 3, 1). + lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01. + narrow (float): The narrow ratio for channels. Default: 1. + sft_half (bool): Whether to apply SFT on half of the input channels. Default: False. 
+    """
+    def __init__(self,
+                 out_size,
+                 num_style_feat=512,
+                 num_mlp=8,
+                 channel_multiplier=2,
+                 resample_kernel=(1, 3, 3, 1),
+                 lr_mlp=0.01,
+                 narrow=1,
+                 sft_half=False):
+        super(StyleGAN2GeneratorSFT,
+              self).__init__(out_size,
+                             num_style_feat=num_style_feat,
+                             num_mlp=num_mlp,
+                             channel_multiplier=channel_multiplier,
+                             resample_kernel=resample_kernel,
+                             lr_mlp=lr_mlp,
+                             narrow=narrow)
+        self.sft_half = sft_half
+
+    def forward(self,
+                styles,
+                conditions,
+                input_is_latent=False,
+                noise=None,
+                randomize_noise=True,
+                truncation=1,
+                truncation_latent=None,
+                inject_index=None,
+                return_latents=False):
+        """Forward function for StyleGAN2GeneratorSFT.
+
+        Args:
+            styles (list[Tensor]): Sample codes of styles.
+            conditions (list[Tensor]): SFT conditions to generators.
+            input_is_latent (bool): Whether input is latent style. Default: False.
+            noise (Tensor | None): Input noise or None. Default: None.
+            randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
+            truncation (float): The truncation ratio. Default: 1.
+            truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
+            inject_index (int | None): The injection index for mixing noise. Default: None.
+            return_latents (bool): Whether to return style latents. Default: False.
+        """
+        if not input_is_latent:
+            styles = [self.style_mlp(s) for s in styles]
+        if noise is None:
+            if randomize_noise:
+                noise = [None] * self.num_layers
+            else:
+                noise = [
+                    getattr(self.noises, f'noise{i}')
+                    for i in range(self.num_layers)
+                ]
+        if truncation < 1:
+            style_truncation = []
+            for style in styles:
+                style_truncation.append(truncation_latent + truncation *
+                                        (style - truncation_latent))
+            styles = style_truncation
+        if len(styles) == 1:
+            inject_index = self.num_latent
+            if styles[0].ndim < 3:
+                latent = paddle.tile(styles[0].unsqueeze(1),
+                                     repeat_times=[1, inject_index, 1])
+            else:
+                latent = styles[0]
+        elif len(styles) == 2:
+            if inject_index is None:
+                inject_index = random.randint(1, self.num_latent - 1)
+            latent1 = styles[0].unsqueeze(1)
+            latent1 = paddle.tile(latent1, repeat_times=[1, inject_index, 1])
+
+            latent2 = styles[1].unsqueeze(1)
+            latent2 = paddle.tile(
+                latent2, repeat_times=[1, self.num_latent - inject_index, 1])
+            latent = paddle.concat([latent1, latent2], 1)
+        out = self.constant_input(latent.shape[0])
+        out = self.style_conv1(out, latent[:, 0], noise=noise[0])
+        skip = self.to_rgb1(out, latent[:, 1])
+        i = 1
+        for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2],
+                                                        self.style_convs[1::2],
+                                                        noise[1::2],
+                                                        noise[2::2],
+                                                        self.to_rgbs):
+            out = conv1(out, latent[:, i], noise=noise1)
+            if i < len(conditions):
+                if self.sft_half:
+                    out_same, out_sft = paddle.split(out, 2, axis=1)
+                    out_sft = out_sft * conditions[i - 1] + conditions[i]
+                    out = paddle.concat([out_same, out_sft], axis=1)
+                else:
+                    out = out * conditions[i - 1] + conditions[i]
+            out = conv2(out, latent[:, i + 1], noise=noise2)
+            skip = to_rgb(out, latent[:, i + 2], skip)
+            i += 2
+        image = skip
+        if return_latents:
+            return image, latent
+        else:
+            return image, None
+
+
+@GENERATORS.register()
+class GFPGANv1(nn.Layer):
+    """The GFPGAN architecture: Unet + StyleGAN2 decoder with SFT.
+
+    Ref: GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior.
+
+    Args:
+        out_size (int): The spatial size of outputs.
+        num_style_feat (int): Channel number of style features. Default: 512.
+        channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
+ resample_kernel (list[int]): A list indicating the 1D resample kernel magnitude. A cross production will be + applied to extent 1D resample kernel to 2D resample kernel. Default: (1, 3, 3, 1). + decoder_load_path (str): The path to the pre-trained decoder model (usually, the StyleGAN2). Default: None. + fix_decoder (bool): Whether to fix the decoder. Default: True. + + num_mlp (int): Layer number of MLP style layers. Default: 8. + lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01. + input_is_latent (bool): Whether input is latent style. Default: False. + different_w (bool): Whether to use different latent w for different layers. Default: False. + narrow (float): The narrow ratio for channels. Default: 1. + sft_half (bool): Whether to apply SFT on half of the input channels. Default: False. + """ + def __init__(self, + out_size, + num_style_feat=512, + channel_multiplier=1, + resample_kernel=(1, 3, 3, 1), + decoder_load_path=None, + fix_decoder=True, + num_mlp=8, + lr_mlp=0.01, + input_is_latent=False, + different_w=False, + narrow=1, + sft_half=False): + super(GFPGANv1, self).__init__() + self.input_is_latent = input_is_latent + self.different_w = different_w + self.num_style_feat = num_style_feat + unet_narrow = narrow * 0.5 + channels = { + '4': int(512 * unet_narrow), + '8': int(512 * unet_narrow), + '16': int(512 * unet_narrow), + '32': int(512 * unet_narrow), + '64': int(256 * channel_multiplier * unet_narrow), + '128': int(128 * channel_multiplier * unet_narrow), + '256': int(64 * channel_multiplier * unet_narrow), + '512': int(32 * channel_multiplier * unet_narrow), + '1024': int(16 * channel_multiplier * unet_narrow) + } + self.log_size = int(math.log(out_size, 2)) + first_out_size = 2**int(math.log(out_size, 2)) + self.conv_body_first = ConvLayer(3, + channels[f'{first_out_size}'], + 1, + bias=True, + activate=True) + in_channels = channels[f'{first_out_size}'] + self.conv_body_down = nn.LayerList() + for i in range(self.log_size, 2, -1): + out_channels = channels[f'{2 ** (i - 1)}'] + self.conv_body_down.append( + ResBlock(in_channels, out_channels, resample_kernel)) + in_channels = out_channels + self.final_conv = ConvLayer(in_channels, + channels['4'], + 3, + bias=True, + activate=True) + in_channels = channels['4'] + self.conv_body_up = nn.LayerList() + for i in range(3, self.log_size + 1): + out_channels = channels[f'{2 ** i}'] + self.conv_body_up.append(ResUpBlock(in_channels, out_channels)) + in_channels = out_channels + self.toRGB = nn.LayerList() + for i in range(3, self.log_size + 1): + self.toRGB.append( + EqualConv2d(channels[f'{2 ** i}'], + 3, + 1, + stride=1, + padding=0, + bias=True, + bias_init_val=0)) + if different_w: + linear_out_channel = (int(math.log(out_size, 2)) * 2 - + 2) * num_style_feat + else: + linear_out_channel = num_style_feat + self.final_linear = EqualLinear(channels['4'] * 4 * 4, + linear_out_channel, + bias=True, + bias_init_val=0, + lr_mul=1, + activation=None) + self.stylegan_decoder = StyleGAN2GeneratorSFT( + out_size=out_size, + num_style_feat=num_style_feat, + num_mlp=num_mlp, + channel_multiplier=channel_multiplier, + resample_kernel=resample_kernel, + lr_mlp=lr_mlp, + narrow=narrow, + sft_half=sft_half) + if decoder_load_path: + decoder_load_path = get_path_from_url(decoder_load_path) + self.stylegan_decoder.set_state_dict(paddle.load(decoder_load_path)) + + if fix_decoder: + for _, param in self.stylegan_decoder.named_parameters(): + param.stop_gradient = True + self.condition_scale = nn.LayerList() + 
self.condition_shift = nn.LayerList() + for i in range(3, self.log_size + 1): + out_channels = channels[f'{2 ** i}'] + if sft_half: + sft_out_channels = out_channels + else: + sft_out_channels = out_channels * 2 + self.condition_scale.append( + nn.Sequential( + EqualConv2d(out_channels, + out_channels, + 3, + stride=1, + padding=1, + bias=True, + bias_init_val=0), ScaledLeakyReLU(0.2), + EqualConv2d(out_channels, + sft_out_channels, + 3, + stride=1, + padding=1, + bias=True, + bias_init_val=1))) + self.condition_shift.append( + nn.Sequential( + EqualConv2d(out_channels, + out_channels, + 3, + stride=1, + padding=1, + bias=True, + bias_init_val=0), ScaledLeakyReLU(0.2), + EqualConv2d(out_channels, + sft_out_channels, + 3, + stride=1, + padding=1, + bias=True, + bias_init_val=0))) + + def forward(self, + x, + return_latents=False, + return_rgb=True, + randomize_noise=False): + """Forward function for GFPGANv1. + + Args: + x (Tensor): Input images. + return_latents (bool): Whether to return style latents. Default: False. + return_rgb (bool): Whether return intermediate rgb images. Default: True. + randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True. + """ + conditions = [] + unet_skips = [] + out_rgbs = [] + + feat = self.conv_body_first(x) + + for i in range(self.log_size - 2): + feat = self.conv_body_down[i](feat) + unet_skips.insert(0, feat) + feat = self.final_conv(feat) + style_code = self.final_linear(feat.reshape([feat.shape[0], -1])) + if self.different_w: + style_code = style_code.reshape( + [style_code.shape[0], -1, self.num_style_feat]) + + for i in range(self.log_size - 2): + feat = feat + unet_skips[i] + feat = self.conv_body_up[i](feat) + scale = self.condition_scale[i](feat) + conditions.append(scale.clone()) + shift = self.condition_shift[i](feat) + conditions.append(shift.clone()) + if return_rgb: + out_rgbs.append(self.toRGB[i](feat)) + image, _ = self.stylegan_decoder([style_code], + conditions, + return_latents=return_latents, + input_is_latent=self.input_is_latent, + randomize_noise=randomize_noise) + return image, out_rgbs + + +class FacialComponentDiscriminator(nn.Layer): + """Facial component (eyes, mouth, noise) discriminator used in GFPGAN. + """ + def __init__(self): + super(FacialComponentDiscriminator, self).__init__() + self.conv1 = ConvLayer(3, + 64, + 3, + downsample=False, + resample_kernel=(1, 3, 3, 1), + bias=True, + activate=True) + self.conv2 = ConvLayer(64, + 128, + 3, + downsample=True, + resample_kernel=(1, 3, 3, 1), + bias=True, + activate=True) + self.conv3 = ConvLayer(128, + 128, + 3, + downsample=False, + resample_kernel=(1, 3, 3, 1), + bias=True, + activate=True) + self.conv4 = ConvLayer(128, + 256, + 3, + downsample=True, + resample_kernel=(1, 3, 3, 1), + bias=True, + activate=True) + self.conv5 = ConvLayer(256, + 256, + 3, + downsample=False, + resample_kernel=(1, 3, 3, 1), + bias=True, + activate=True) + self.final_conv = ConvLayer(256, 1, 3, bias=True, activate=False) + + def forward(self, x, return_feats=False): + """Forward function for FacialComponentDiscriminator. + + Args: + x (Tensor): Input images. + return_feats (bool): Whether to return intermediate features. Default: False. 
+        """
+        feat = self.conv1(x)
+        feat = self.conv3(self.conv2(feat))
+        rlt_feats = []
+        if return_feats:
+            rlt_feats.append(feat.clone())
+        feat = self.conv5(self.conv4(feat))
+        if return_feats:
+            rlt_feats.append(feat.clone())
+        out = self.final_conv(feat)
+        if return_feats:
+            return out, rlt_feats
+        else:
+            return out, None
+
+
+class ConvUpLayer(nn.Layer):
+    """Convolutional upsampling layer. It uses bilinear upsampler + Conv.
+
+    Args:
+        in_channels (int): Channel number of the input.
+        out_channels (int): Channel number of the output.
+        kernel_size (int): Size of the convolving kernel.
+        stride (int): Stride of the convolution. Default: 1.
+        padding (int): Zero-padding added to both sides of the input. Default: 0.
+        bias (bool): If ``True``, adds a learnable bias to the output. Default: ``True``.
+        bias_init_val (float): Bias initialized value. Default: 0.
+        activate (bool): Whether to use activation. Default: True.
+    """
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 stride=1,
+                 padding=0,
+                 bias=True,
+                 bias_init_val=0,
+                 activate=True):
+        super(ConvUpLayer, self).__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.kernel_size = kernel_size
+        self.stride = stride
+        self.padding = padding
+        self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
+        self.weight = paddle.create_parameter(
+            shape=[out_channels, in_channels, kernel_size, kernel_size],
+            dtype='float32',
+            default_initializer=paddle.nn.initializer.Normal())
+        if bias and not activate:
+            self.bias = paddle.create_parameter(
+                shape=[out_channels],
+                dtype='float32',
+                default_initializer=paddle.nn.initializer.Constant(
+                    bias_init_val))
+        else:
+            # bias is either disabled or fused into the activation below
+            self.bias = None
+        if activate:
+            if bias:
+                self.activation = FusedLeakyReLU(out_channels)
+            else:
+                self.activation = ScaledLeakyReLU(0.2)
+        else:
+            self.activation = None
+
+    def forward(self, x):
+        out = F.interpolate(x,
+                            scale_factor=2,
+                            mode='bilinear',
+                            align_corners=False)
+        out = F.conv2d(out,
+                       self.weight * self.scale,
+                       bias=self.bias,
+                       stride=self.stride,
+                       padding=self.padding)
+        if self.activation is not None:
+            out = self.activation(out)
+        return out
+
+
+class ResUpBlock(nn.Layer):
+    """Residual block with upsampling.
+
+    Args:
+        in_channels (int): Channel number of the input.
+        out_channels (int): Channel number of the output.
+ """ + def __init__(self, in_channels, out_channels): + super(ResUpBlock, self).__init__() + self.conv1 = ConvLayer(in_channels, + in_channels, + 3, + bias=True, + activate=True) + self.conv2 = ConvUpLayer(in_channels, + out_channels, + 3, + stride=1, + padding=1, + bias=True, + activate=True) + self.skip = ConvUpLayer(in_channels, + out_channels, + 1, + bias=False, + activate=False) + + def forward(self, x): + out = self.conv1(x) + out = self.conv2(out) + skip = self.skip(x) + out = (out + skip) / math.sqrt(2) + return out + + +def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, + pad_y0, pad_y1): + _, channel, in_h, in_w = input.shape + input = input.reshape((-1, in_h, in_w, 1)) + _, in_h, in_w, minor = input.shape + kernel_h, kernel_w = kernel.shape + out = input.reshape((-1, in_h, 1, in_w, 1, minor)) + out = out.transpose((0, 1, 3, 5, 2, 4)) + out = out.reshape((-1, 1, 1, 1)) + out = F.pad(out, [0, up_x - 1, 0, up_y - 1]) + out = out.reshape((-1, in_h, in_w, minor, up_y, up_x)) + out = out.transpose((0, 3, 1, 4, 2, 5)) + out = out.reshape((-1, minor, in_h * up_y, in_w * up_x)) + out = F.pad( + out, [max(pad_x0, 0), + max(pad_x1, 0), + max(pad_y0, 0), + max(pad_y1, 0)]) + out = out[:, :, + max(-pad_y0, 0):out.shape[2] - max(-pad_y1, 0), + max(-pad_x0, 0):out.shape[3] - max(-pad_x1, 0)] + out = out.reshape( + [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]) + w = paddle.flip(kernel, [0, 1]).reshape((1, 1, kernel_h, kernel_w)) + out = F.conv2d(out, w) + out = out.reshape((-1, minor, in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, + in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1)) + out = out.transpose((0, 2, 3, 1)) + out = out[:, ::down_y, ::down_x, :] + out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 + out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 + return out.reshape((-1, channel, out_h, out_w)) + + +def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): + out = upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], + pad[0], pad[1]) + return out + + +class NormStyleCode(nn.Layer): + def forward(self, x): + """Normalize the style codes. + + Args: + x (Tensor): Style codes with shape (b, c). + + Returns: + Tensor: Normalized tensor. + """ + return x * paddle.rsqrt(paddle.mean(x**2, axis=1, keepdim=True) + 1e-08) + + +def make_resample_kernel(k): + """Make resampling kernel for UpFirDn. + + Args: + k (list[int]): A list indicating the 1D resample kernel magnitude. + + Returns: + Tensor: 2D resampled kernel. + """ + k = paddle.to_tensor(k, dtype="float32") + if k.ndim == 1: + k = k[None, :] * k[:, None] + k /= k.sum() + return k + + +class UpFirDnUpsample(nn.Layer): + """Upsample, FIR filter, and downsample (upsampole version). + + References: + 1. https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.upfirdn.html # noqa: E501 + 2. http://www.ece.northwestern.edu/local-apps/matlabhelp/toolbox/signal/upfirdn.html # noqa: E501 + + Args: + resample_kernel (list[int]): A list indicating the 1D resample kernel + magnitude. + factor (int): Upsampling scale factor. Default: 2. 
+ """ + def __init__(self, resample_kernel, factor=2): + super(UpFirDnUpsample, self).__init__() + self.kernel = make_resample_kernel(resample_kernel) * factor**2 + self.factor = factor + pad = self.kernel.shape[0] - factor + self.pad = (pad + 1) // 2 + factor - 1, pad // 2 + + def forward(self, x): + out = upfirdn2d(x, self.kernel, up=self.factor, down=1, pad=self.pad) + return out + + def __repr__(self): + return f'{self.__class__.__name__}(factor={self.factor})' + + +class UpFirDnDownsample(nn.Layer): + """Upsample, FIR filter, and downsample (downsampole version). + + Args: + resample_kernel (list[int]): A list indicating the 1D resample kernel + magnitude. + factor (int): Downsampling scale factor. Default: 2. + """ + def __init__(self, resample_kernel, factor=2): + super(UpFirDnDownsample, self).__init__() + self.kernel = make_resample_kernel(resample_kernel) + self.factor = factor + pad = self.kernel.shape[0] - factor + self.pad = (pad + 1) // 2, pad // 2 + + def forward(self, x): + out = upfirdn2d(x, self.kernel, up=1, down=self.factor, pad=self.pad) + return out + + def __repr__(self): + return f'{self.__class__.__name__}(factor={self.factor})' + + +class UpFirDnSmooth(nn.Layer): + """Upsample, FIR filter, and downsample (smooth version). + + Args: + resample_kernel (list[int]): A list indicating the 1D resample kernel + magnitude. + upsample_factor (int): Upsampling scale factor. Default: 1. + downsample_factor (int): Downsampling scale factor. Default: 1. + kernel_size (int): Kernel size: Deafult: 1. + """ + def __init__(self, + resample_kernel, + upsample_factor=1, + downsample_factor=1, + kernel_size=1): + super(UpFirDnSmooth, self).__init__() + self.upsample_factor = upsample_factor + self.downsample_factor = downsample_factor + self.kernel = make_resample_kernel(resample_kernel) + if upsample_factor > 1: + self.kernel = self.kernel * upsample_factor**2 + if upsample_factor > 1: + pad = self.kernel.shape[0] - upsample_factor - (kernel_size - 1) + self.pad = (pad + 1) // 2 + upsample_factor - 1, pad // 2 + 1 + elif downsample_factor > 1: + pad = self.kernel.shape[0] - downsample_factor + (kernel_size - 1) + self.pad = (pad + 1) // 2, pad // 2 + else: + raise NotImplementedError + + def forward(self, x): + out = upfirdn2d(x, self.kernel, up=1, down=1, pad=self.pad) + return out + + def __repr__(self): + return ( + f'{self.__class__.__name__}(upsample_factor={self.upsample_factor}, \ + downsample_factor={self.downsample_factor})') + + +class EqualLinear(nn.Layer): + """This linear layer class stabilizes the learning rate changes of its parameters. + Equalizing learning rate keeps the weights in the network at a similar scale during training. 
+ """ + def __init__(self, + in_dim, + out_dim, + bias=True, + bias_init_val=0, + lr_mul=1, + activation=None): + super().__init__() + + self.weight = paddle.create_parameter( + (in_dim, out_dim), + default_initializer=nn.initializer.Normal(), + dtype='float32') + self.weight.set_value((self.weight / lr_mul)) + + if bias: + self.bias = self.create_parameter( + (out_dim, ), nn.initializer.Constant(bias_init_val)) + + else: + self.bias = None + + self.activation = activation + + self.scale = (1 / math.sqrt(in_dim)) * lr_mul + self.lr_mul = lr_mul + + def forward(self, input): + if self.activation: + out = F.linear(input, self.weight * self.scale) + out = fused_leaky_relu(out, self.bias * self.lr_mul) + + else: + out = F.linear(input, + self.weight * self.scale, + bias=self.bias * self.lr_mul) + + return out + + def __repr__(self): + return ( + f"{self.__class__.__name__}({self.weight.shape[0]}, {self.weight.shape[1]})" + ) + + +class ModulatedConv2d(nn.Layer): + """Modulated Conv2d used in StyleGAN2. + + There is no bias in ModulatedConv2d. + + Args: + in_channels (int): Channel number of the input. + out_channels (int): Channel number of the output. + kernel_size (int): Size of the convolving kernel. + num_style_feat (int): Channel number of style features. + demodulate (bool): Whether to demodulate in the conv layer. + Default: True. + sample_mode (str | None): Indicating 'upsample', 'downsample' or None. + Default: None. + resample_kernel (list[int]): A list indicating the 1D resample kernel + magnitude. Default: (1, 3, 3, 1). + eps (float): A value added to the denominator for numerical stability. + Default: 1e-8. + """ + def __init__(self, + in_channels, + out_channels, + kernel_size, + num_style_feat, + demodulate=True, + sample_mode=None, + resample_kernel=(1, 3, 3, 1), + eps=1e-08): + super(ModulatedConv2d, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.demodulate = demodulate + self.sample_mode = sample_mode + self.eps = eps + if self.sample_mode == 'upsample': + self.smooth = UpFirDnSmooth(resample_kernel, + upsample_factor=2, + downsample_factor=1, + kernel_size=kernel_size) + elif self.sample_mode == 'downsample': + self.smooth = UpFirDnSmooth(resample_kernel, + upsample_factor=1, + downsample_factor=2, + kernel_size=kernel_size) + elif self.sample_mode is None: + pass + else: + raise ValueError( + f"Wrong sample mode {self.sample_mode}, supported ones are ['upsample', 'downsample', None]." + ) + self.scale = 1 / math.sqrt(in_channels * kernel_size**2) + self.modulation = EqualLinear(num_style_feat, + in_channels, + bias=True, + bias_init_val=1, + lr_mul=1, + activation=None) + self.weight = paddle.create_parameter( + shape=[1, out_channels, in_channels, kernel_size, kernel_size], + dtype='float32', + default_initializer=paddle.nn.initializer.Normal()) + self.padding = kernel_size // 2 + + def forward(self, x, style): + """Forward function. + + Args: + x (Tensor): Tensor with shape (b, c, h, w). + style (Tensor): Tensor with shape (b, num_style_feat). + + Returns: + Tensor: Modulated tensor after convolution. 
+ """ + b, c, h, w = x.shape + style = self.modulation(style).reshape([b, 1, c, 1, 1]) + weight = self.scale * self.weight * style + if self.demodulate: + demod = paddle.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps) + weight = weight * demod.reshape([b, self.out_channels, 1, 1, 1]) + weight = weight.reshape( + [b * self.out_channels, c, self.kernel_size, self.kernel_size]) + if self.sample_mode == 'upsample': + x = x.reshape([1, b * c, h, w]) + weight = weight.reshape( + [b, self.out_channels, c, self.kernel_size, self.kernel_size]) + weight = weight.transpose([0, 2, 1, 3, 4]).reshape( + [b * c, self.out_channels, self.kernel_size, self.kernel_size]) + out = F.conv2d_transpose(x, weight, padding=0, stride=2, groups=b) + out = out.reshape([b, self.out_channels, *out.shape[2:4]]) + out = self.smooth(out) + elif self.sample_mode == 'downsample': + x = self.smooth(x) + x = x.reshape([1, b * c, *x.shape[2:4]]) + out = F.conv2d(x, weight, padding=0, stride=2, groups=b) + out = out.reshape([b, self.out_channels, *out.shape[2:4]]) + else: + x = x.reshape([1, b * c, h, w]) + out = F.conv2d(x, weight, padding=self.padding, groups=b) + out = out.reshape([b, self.out_channels, *out.shape[2:4]]) + return out + + def __repr__(self): + return (f'{self.__class__.__name__}(in_channels={self.in_channels}, \ + out_channels={self.out_channels}, \ + kernel_size={self.kernel_size}, \ + demodulate={self.demodulate}, \ + sample_mode={self.sample_mode})') + + +class StyleConv(nn.Layer): + """Style conv. + + Args: + in_channels (int): Channel number of the input. + out_channels (int): Channel number of the output. + kernel_size (int): Size of the convolving kernel. + num_style_feat (int): Channel number of style features. + demodulate (bool): Whether demodulate in the conv layer. Default: True. + sample_mode (str | None): Indicating 'upsample', 'downsample' or None. + Default: None. + resample_kernel (list[int]): A list indicating the 1D resample kernel + magnitude. Default: (1, 3, 3, 1). + """ + def __init__(self, + in_channels, + out_channels, + kernel_size, + num_style_feat, + demodulate=True, + sample_mode=None, + resample_kernel=(1, 3, 3, 1)): + super(StyleConv, self).__init__() + self.modulated_conv = ModulatedConv2d(in_channels, + out_channels, + kernel_size, + num_style_feat, + demodulate=demodulate, + sample_mode=sample_mode, + resample_kernel=resample_kernel) + self.weight = paddle.create_parameter( + shape=[1], + dtype='float32', + default_initializer=paddle.nn.initializer.Constant(0.)) + self.activate = FusedLeakyReLU(out_channels) + + def forward(self, x, style, noise=None): + out = self.modulated_conv(x, style) + if noise is None: + b, _, h, w = out.shape + noise = paddle.normal(shape=[b, 1, h, w]) + out = out + self.weight * noise + out = self.activate(out) + return out + + +class ToRGB(nn.Layer): + """To RGB from features. + + Args: + in_channels (int): Channel number of input. + num_style_feat (int): Channel number of style features. + upsample (bool): Whether to upsample. Default: True. + resample_kernel (list[int]): A list indicating the 1D resample kernel + magnitude. Default: (1, 3, 3, 1). 
+ """ + def __init__(self, + in_channels, + num_style_feat, + upsample=True, + resample_kernel=(1, 3, 3, 1)): + super(ToRGB, self).__init__() + if upsample: + self.upsample = UpFirDnUpsample(resample_kernel, factor=2) + else: + self.upsample = None + self.modulated_conv = ModulatedConv2d(in_channels, + 3, + kernel_size=1, + num_style_feat=num_style_feat, + demodulate=False, + sample_mode=None) + self.bias = paddle.create_parameter( + shape=[1, 3, 1, 1], + dtype='float32', + default_initializer=paddle.nn.initializer.Constant(0)) + + def forward(self, x, style, skip=None): + """Forward function. + + Args: + x (Tensor): Feature tensor with shape (b, c, h, w). + style (Tensor): Tensor with shape (b, num_style_feat). + skip (Tensor): Base/skip tensor. Default: None. + + Returns: + Tensor: RGB images. + """ + out = self.modulated_conv(x, style) + out = out + self.bias + if skip is not None: + if self.upsample: + skip = self.upsample(skip) + out = out + skip + return out + + +class ConstantInput(nn.Layer): + """Constant input. + + Args: + num_channel (int): Channel number of constant input. + size (int): Spatial size of constant input. + """ + def __init__(self, num_channel, size): + super(ConstantInput, self).__init__() + self.weight = paddle.create_parameter( + shape=[1, num_channel, size, size], + dtype='float32', + default_initializer=paddle.nn.initializer.Normal()) + + def forward(self, batch): + out = paddle.tile(self.weight, repeat_times=[batch, 1, 1, 1]) + return out + + +class FusedLeakyReLU(nn.Layer): + def __init__(self, channel, bias=True, negative_slope=0.2, scale=2**0.5): + super().__init__() + if bias: + self.bias = self.create_parameter( + (channel, ), default_initializer=nn.initializer.Constant(0.0)) + else: + self.bias = None + self.negative_slope = negative_slope + self.scale = scale + + def forward(self, input): + return fused_leaky_relu(input, self.bias, self.negative_slope, + self.scale) + + +def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2**0.5): + if bias is not None: + rest_dim = [1] * (len(input.shape) - len(bias.shape) - 1) + return F.leaky_relu(input + bias.reshape([1, bias.shape[0], *rest_dim]), + negative_slope=0.2) * scale + else: + return F.leaky_relu(input, negative_slope=0.2) * scale + + +class ScaledLeakyReLU(nn.Layer): + """Scaled LeakyReLU. + + Args: + negative_slope (float): Negative slope. Default: 0.2. + """ + def __init__(self, negative_slope=0.2): + super(ScaledLeakyReLU, self).__init__() + self.negative_slope = negative_slope + + def forward(self, x): + out = F.leaky_relu(x, negative_slope=self.negative_slope) + return out * math.sqrt(2) + + +class EqualConv2d(nn.Layer): + """Equalized Linear as StyleGAN2. + + Args: + in_channels (int): Channel number of the input. + out_channels (int): Channel number of the output. + kernel_size (int): Size of the convolving kernel. + stride (int): Stride of the convolution. Default: 1 + padding (int): Zero-padding added to both sides of the input. + Default: 0. + bias (bool): If ``True``, adds a learnable bias to the output. + Default: ``True``. + bias_init_val (float): Bias initialized value. Default: 0. 
+    """
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 stride=1,
+                 padding=0,
+                 bias=True,
+                 bias_init_val=0):
+        super(EqualConv2d, self).__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.kernel_size = kernel_size
+        self.stride = stride
+        self.padding = padding
+        self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
+        self.weight = paddle.create_parameter(
+            shape=[out_channels, in_channels, kernel_size, kernel_size],
+            dtype='float32',
+            default_initializer=paddle.nn.initializer.Normal())
+        if bias:
+            self.bias = paddle.create_parameter(
+                shape=[out_channels],
+                dtype='float32',
+                default_initializer=paddle.nn.initializer.Constant(
+                    bias_init_val))
+        else:
+            self.bias = None
+
+    def forward(self, x):
+        out = F.conv2d(x,
+                       self.weight * self.scale,
+                       bias=self.bias,
+                       stride=self.stride,
+                       padding=self.padding)
+        return out
+
+    def __repr__(self):
+        return (f'{self.__class__.__name__}(in_channels={self.in_channels}, \
+            out_channels={self.out_channels}, kernel_size={self.kernel_size}, \
+            stride={self.stride}, padding={self.padding}, \
+            bias={self.bias is not None})')
+
+
+class ConvLayer(nn.Sequential):
+    """Conv Layer used in StyleGAN2 Discriminator.
+
+    Args:
+        in_channels (int): Channel number of the input.
+        out_channels (int): Channel number of the output.
+        kernel_size (int): Kernel size.
+        downsample (bool): Whether downsample by a factor of 2.
+            Default: False.
+        resample_kernel (list[int]): A list indicating the 1D resample
+            kernel magnitude. A cross product will be applied to
+            extend the 1D resample kernel to a 2D resample kernel.
+            Default: (1, 3, 3, 1).
+        bias (bool): Whether with bias. Default: True.
+        activate (bool): Whether to use activation. Default: True.
+    """
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 downsample=False,
+                 resample_kernel=(1, 3, 3, 1),
+                 bias=True,
+                 activate=True):
+        layers = []
+        if downsample:
+            layers.append(
+                UpFirDnSmooth(resample_kernel,
+                              upsample_factor=1,
+                              downsample_factor=2,
+                              kernel_size=kernel_size))
+            stride = 2
+            self.padding = 0
+        else:
+            stride = 1
+            self.padding = kernel_size // 2
+        layers.append(
+            EqualConv2d(in_channels,
+                        out_channels,
+                        kernel_size,
+                        stride=stride,
+                        padding=self.padding,
+                        bias=bias and not activate))
+        if activate:
+            if bias:
+                layers.append(FusedLeakyReLU(out_channels))
+            else:
+                layers.append(ScaledLeakyReLU(0.2))
+        super(ConvLayer, self).__init__(*layers)
+
+
+class ResBlock(nn.Layer):
+    """Residual block used in StyleGAN2 Discriminator.
+
+    Args:
+        in_channels (int): Channel number of the input.
+        out_channels (int): Channel number of the output.
+        resample_kernel (list[int]): A list indicating the 1D resample
+            kernel magnitude. A cross product will be applied to
+            extend the 1D resample kernel to a 2D resample kernel.
+            Default: (1, 3, 3, 1).
+ """ + def __init__(self, in_channels, out_channels, resample_kernel=(1, 3, 3, 1)): + super(ResBlock, self).__init__() + self.conv1 = ConvLayer(in_channels, + in_channels, + 3, + bias=True, + activate=True) + self.conv2 = ConvLayer(in_channels, + out_channels, + 3, + downsample=True, + resample_kernel=resample_kernel, + bias=True, + activate=True) + self.skip = ConvLayer(in_channels, + out_channels, + 1, + downsample=True, + resample_kernel=resample_kernel, + bias=False, + activate=False) + + def forward(self, x): + out = self.conv1(x) + out = self.conv2(out) + skip = self.skip(x) + out = (out + skip) / math.sqrt(2.) + return out diff --git a/ppgan/models/generators/gfpganv1_clean_arch.py b/ppgan/models/generators/gfpganv1_clean_arch.py new file mode 100644 index 0000000..6568f62 --- /dev/null +++ b/ppgan/models/generators/gfpganv1_clean_arch.py @@ -0,0 +1,329 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +import random + +import paddle +from paddle import nn +from paddle.nn import functional as F + +from ppgan.models.generators.stylegan2_clean_arch import StyleGAN2GeneratorClean +from ppgan.models.generators.builder import GENERATORS + + +class StyleGAN2GeneratorCSFT(StyleGAN2GeneratorClean): + """StyleGAN2 Generator with SFT modulation (Spatial Feature Transform). + + It is the clean version without custom compiled CUDA extensions used in StyleGAN2. + + Args: + out_size (int): The spatial size of outputs. + num_style_feat (int): Channel number of style features. Default: 512. + num_mlp (int): Layer number of MLP style layers. Default: 8. + channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2. + narrow (float): The narrow ratio for channels. Default: 1. + sft_half (bool): Whether to apply SFT on half of the input channels. Default: False. + """ + def __init__(self, + out_size, + num_style_feat=512, + num_mlp=8, + channel_multiplier=2, + narrow=1, + sft_half=False): + super(StyleGAN2GeneratorCSFT, + self).__init__(out_size, + num_style_feat=num_style_feat, + num_mlp=num_mlp, + channel_multiplier=channel_multiplier, + narrow=narrow) + self.sft_half = sft_half + + def forward(self, + styles, + conditions, + input_is_latent=False, + noise=None, + randomize_noise=True, + truncation=1, + truncation_latent=None, + inject_index=None, + return_latents=False): + """Forward function for StyleGAN2GeneratorCSFT. + + Args: + styles (list[Tensor]): Sample codes of styles. + conditions (list[Tensor]): SFT conditions to generators. + input_is_latent (bool): Whether input is latent style. Default: False. + noise (Tensor | None): Input noise or None. Default: None. + randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True. + truncation (float): The truncation ratio. Default: 1. + truncation_latent (Tensor | None): The truncation latent tensor. Default: None. + inject_index (int | None): The injection index for mixing noise. Default: None. 
+ return_latents (bool): Whether to return style latents. Default: False. + """ + if not input_is_latent: + styles = [self.style_mlp(s) for s in styles] + if noise is None: + if randomize_noise: + noise = [None] * self.num_layers + else: + noise = [ + getattr(self.noises, f'noise{i}') + for i in range(self.num_layers) + ] + if truncation < 1: + style_truncation = [] + for style in styles: + style_truncation.append(truncation_latent + truncation * + (style - truncation_latent)) + styles = style_truncation + if len(styles) == 1: + inject_index = self.num_latent + if styles[0].ndim < 3: + latent = paddle.tile(styles[0].unsqueeze(1), + repeat_times=[1, inject_index, 1]) + else: + latent = styles[0] + elif len(styles) == 2: + if inject_index is None: + inject_index = random.randint(1, self.num_latent - 1) + latent1 = paddle.tile(styles[0].unsqueeze(1), + repeat_times=[1, inject_index, 1]) + latent2 = paddle.tile( + styles[1].unsqueeze(1), + repeat_times=[1, self.num_latent - inject_index, 1]) + latent = paddle.concat([latent1, latent2], axis=1) + out = self.constant_input(latent.shape[0]) + out = self.style_conv1(out, latent[:, 0], noise=noise[0]) + skip = self.to_rgb1(out, latent[:, 1]) + i = 1 + for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], + self.style_convs[1::2], + noise[1::2], + noise[2::2], + self.to_rgbs): + out = conv1(out, latent[:, i], noise=noise1) + if i < len(conditions): + if self.sft_half: + out_same, out_sft = paddle.split(out, 2, axis=1) + + out_sft = out_sft * conditions[i - 1] + conditions[i] + out = paddle.concat([out_same, out_sft], axis=1) + else: + out = out * conditions[i - 1] + conditions[i] + out = conv2(out, latent[:, i + 1], noise=noise2) + skip = to_rgb(out, latent[:, i + 2], skip) + i += 2 + image = skip + if return_latents: + + return image, latent + else: + return image, None + + +class ResBlock(nn.Layer): + """Residual block with bilinear upsampling/downsampling. + + Args: + in_channels (int): Channel number of the input. + out_channels (int): Channel number of the output. + mode (str): Upsampling/downsampling mode. Options: down | up. Default: down. + """ + def __init__(self, in_channels, out_channels, mode='down'): + super(ResBlock, self).__init__() + self.conv1 = nn.Conv2D(in_channels, in_channels, 3, 1, 1) + self.conv2 = nn.Conv2D(in_channels, out_channels, 3, 1, 1) + self.skip = nn.Conv2D(in_channels, out_channels, 1, bias_attr=False) + if mode == 'down': + self.scale_factor = 0.5 + elif mode == 'up': + self.scale_factor = 2 + + def forward(self, x): + out = paddle.nn.functional.leaky_relu(self.conv1(x), negative_slope=0.2) + out = F.interpolate(out, scale_factor=self.scale_factor, mode=\ + 'bilinear', align_corners=False) + out = paddle.nn.functional.leaky_relu(self.conv2(out), + negative_slope=0.2) + x = F.interpolate(x, scale_factor=self.scale_factor, mode=\ + 'bilinear', align_corners=False) + skip = self.skip(x) + out = out + skip + return out + + +def debug(x): + print(type(x)) + if isinstance(x, list): + for i, v in enumerate(x): + print(i, v.shape) + else: + print(0, x.shape) + + +@GENERATORS.register() +class GFPGANv1Clean(nn.Layer): + """The GFPGAN architecture: Unet + StyleGAN2 decoder with SFT. + + It is the clean version without custom compiled CUDA extensions used in StyleGAN2. + + Ref: GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior. + + Args: + out_size (int): The spatial size of outputs. + num_style_feat (int): Channel number of style features. Default: 512. 
+        channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
+        decoder_load_path (str): The path to the pre-trained decoder model (usually, the StyleGAN2). Default: None.
+        fix_decoder (bool): Whether to fix the decoder. Default: True.
+
+        num_mlp (int): Layer number of MLP style layers. Default: 8.
+        input_is_latent (bool): Whether input is latent style. Default: False.
+        different_w (bool): Whether to use different latent w for different layers. Default: False.
+        narrow (float): The narrow ratio for channels. Default: 1.
+        sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
+    """
+    def __init__(self,
+                 out_size,
+                 num_style_feat=512,
+                 channel_multiplier=1,
+                 decoder_load_path=None,
+                 fix_decoder=True,
+                 num_mlp=8,
+                 input_is_latent=False,
+                 different_w=False,
+                 narrow=1,
+                 sft_half=False):
+        super(GFPGANv1Clean, self).__init__()
+        self.input_is_latent = input_is_latent
+        self.different_w = different_w
+        self.num_style_feat = num_style_feat
+        unet_narrow = narrow * 0.5
+        print("unet_narrow", unet_narrow, "channel_multiplier",
+              channel_multiplier)
+        channels = {
+            '4': int(512 * unet_narrow),
+            '8': int(512 * unet_narrow),
+            '16': int(512 * unet_narrow),
+            '32': int(512 * unet_narrow),
+            '64': int(256 * channel_multiplier * unet_narrow),
+            '128': int(128 * channel_multiplier * unet_narrow),
+            '256': int(64 * channel_multiplier * unet_narrow),
+            '512': int(32 * channel_multiplier * unet_narrow),
+            '1024': int(16 * channel_multiplier * unet_narrow)
+        }
+
+        self.log_size = int(math.log(out_size, 2))
+        first_out_size = 2**int(math.log(out_size, 2))
+        self.conv_body_first = nn.Conv2D(3, channels[f'{first_out_size}'], 1)
+        in_channels = channels[f'{first_out_size}']
+        self.conv_body_down = nn.LayerList()
+        for i in range(self.log_size, 2, -1):
+            out_channels = channels[f'{2 ** (i - 1)}']
+            self.conv_body_down.append(
+                ResBlock(in_channels, out_channels, mode='down'))
+            in_channels = out_channels
+        self.final_conv = nn.Conv2D(in_channels, channels['4'], 3, 1, 1)
+        in_channels = channels['4']
+        self.conv_body_up = nn.LayerList()
+        for i in range(3, self.log_size + 1):
+            out_channels = channels[f'{2 ** i}']
+            self.conv_body_up.append(
+                ResBlock(in_channels, out_channels, mode='up'))
+            in_channels = out_channels
+        self.toRGB = nn.LayerList()
+        for i in range(3, self.log_size + 1):
+            self.toRGB.append(nn.Conv2D(channels[f'{2 ** i}'], 3, 1))
+        if different_w:
+            linear_out_channel = (int(math.log(out_size, 2)) * 2 -
+                                  2) * num_style_feat
+        else:
+            linear_out_channel = num_style_feat
+        self.final_linear = nn.Linear(channels['4'] * 4 * 4, linear_out_channel)
+        self.stylegan_decoder = StyleGAN2GeneratorCSFT(
+            out_size=out_size,
+            num_style_feat=num_style_feat,
+            num_mlp=num_mlp,
+            channel_multiplier=channel_multiplier,
+            narrow=narrow,
+            sft_half=sft_half)
+        if decoder_load_path:
+            self.stylegan_decoder.set_state_dict(
+                paddle.load(decoder_load_path)['params_ema'])
+        if fix_decoder:
+            for _, param in self.stylegan_decoder.named_parameters():
+                param.stop_gradient = True
+        self.condition_scale = nn.LayerList()
+        self.condition_shift = nn.LayerList()
+        for i in range(3, self.log_size + 1):
+            out_channels = channels[f'{2 ** i}']
+            if sft_half:
+                sft_out_channels = out_channels
+            else:
+                sft_out_channels = out_channels * 2
+            self.condition_scale.append(
+                nn.Sequential(
+                    nn.Conv2D(out_channels, out_channels, 3, 1, 1),
+                    nn.LeakyReLU(0.2),
+                    nn.Conv2D(out_channels, sft_out_channels, 3, 1, 1)))
+            self.condition_shift.append(
+
nn.Sequential( + nn.Conv2D(out_channels, out_channels, 3, 1, 1), + nn.LeakyReLU(0.2, True), + nn.Conv2D(out_channels, sft_out_channels, 3, 1, 1))) + + def forward(self, + x, + return_latents=False, + return_rgb=True, + randomize_noise=True): + """Forward function for GFPGANv1Clean. + + Args: + x (Tensor): Input images. + return_latents (bool): Whether to return style latents. Default: False. + return_rgb (bool): Whether return intermediate rgb images. Default: True. + randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True. + """ + conditions = [] + unet_skips = [] + out_rgbs = [] + feat = paddle.nn.functional.leaky_relu(self.conv_body_first(x), + negative_slope=0.2) + for i in range(self.log_size - 2): + feat = self.conv_body_down[i](feat) + unet_skips.insert(0, feat) + feat = paddle.nn.functional.leaky_relu(self.final_conv(feat), + negative_slope=0.2) + style_code = self.final_linear(feat.reshape([feat.shape[0], -1])) + if self.different_w: + style_code = style_code.reshape( + [style_code.shape[0], -1, self.num_style_feat]) + for i in range(self.log_size - 2): + feat = feat + unet_skips[i] + feat = self.conv_body_up[i](feat) + scale = self.condition_scale[i](feat) + conditions.append(scale.clone()) + shift = self.condition_shift[i](feat) + conditions.append(shift.clone()) + if return_rgb: + out_rgbs.append(self.toRGB[i](feat)) + + image, _ = self.stylegan_decoder(styles=[style_code], + conditions=conditions, + return_latents=return_latents, + input_is_latent=self.input_is_latent, + randomize_noise=randomize_noise) + if return_latents: + return image, _ + else: + return image, out_rgbs diff --git a/ppgan/models/generators/stylegan2_clean_arch.py b/ppgan/models/generators/stylegan2_clean_arch.py new file mode 100644 index 0000000..50f66d9 --- /dev/null +++ b/ppgan/models/generators/stylegan2_clean_arch.py @@ -0,0 +1,396 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +import random + +import paddle +from paddle import nn +from paddle.nn import functional as F + + +class NormStyleCode(nn.Layer): + def forward(self, x): + """Normalize the style codes. + + Args: + x (Tensor): Style codes with shape (b, c). + + Returns: + Tensor: Normalized tensor. + """ + return x * paddle.rsqrt(paddle.mean(x ** 2, axis=1, keepdim=\ + True) + 1e-08) + + +class ModulatedConv2d(nn.Layer): + """Modulated Conv2d used in StyleGAN2. + + There is no bias in ModulatedConv2d. + + Args: + in_channels (int): Channel number of the input. + out_channels (int): Channel number of the output. + kernel_size (int): Size of the convolving kernel. + num_style_feat (int): Channel number of style features. + demodulate (bool): Whether to demodulate in the conv layer. Default: True. + sample_mode (str | None): Indicating 'upsample', 'downsample' or None. Default: None. + eps (float): A value added to the denominator for numerical stability. Default: 1e-8. 
+ """ + def __init__(self, + in_channels, + out_channels, + kernel_size, + num_style_feat, + demodulate=True, + sample_mode=None, + eps=1e-08): + super(ModulatedConv2d, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.demodulate = demodulate + self.sample_mode = sample_mode + self.eps = eps + self.modulation = nn.Linear(num_style_feat, in_channels, bias_attr=True) + # default_init_weights(self.modulation, scale=1, bias_fill=1, a=0, + # mode='fan_in', nonlinearity='linear') + x=paddle.randn(shape=[1, out_channels, in_channels, kernel_size, kernel_size],dtype='float32')/math. \ + sqrt(in_channels * kernel_size ** 2) + + self.weight = paddle.create_parameter( + shape=x.shape, + dtype='float32', + default_initializer=paddle.nn.initializer.Assign(x)) + self.weight.stop_gradient = False + self.padding = kernel_size // 2 + + def forward(self, x, style): + """Forward function. + + Args: + x (Tensor): Tensor with shape (b, c, h, w). + style (Tensor): Tensor with shape (b, num_style_feat). + + Returns: + Tensor: Modulated tensor after convolution. + """ + b, c, h, w = x.shape + style = self.modulation(style).reshape([b, 1, c, 1, 1]) + weight = self.weight * style + if self.demodulate: + demod = paddle.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps) + weight = weight * demod.reshape([b, self.out_channels, 1, 1, 1]) + weight = weight.reshape( + [b * self.out_channels, c, self.kernel_size, self.kernel_size]) + if self.sample_mode == 'upsample': + x = F.interpolate(x, + scale_factor=2, + mode='bilinear', + align_corners=False) + elif self.sample_mode == 'downsample': + x = F.interpolate(x, + scale_factor=0.5, + mode='bilinear', + align_corners=False) + b, c, h, w = x.shape + x = x.reshape([1, b * c, h, w]) + out = paddle.nn.functional.conv2d(x, + weight, + padding=self.padding, + groups=b) + out = out.reshape([b, self.out_channels, *out.shape[2:4]]) + return out + + def __repr__(self): + return (f'{self.__class__.__name__}(in_channels={self.in_channels}, \ + out_channels={self.out_channels}, \ + kernel_size={self.kernel_size}, \ + demodulate={self.demodulate}, \ + sample_mode={self.sample_mode})') + + +class StyleConv(nn.Layer): + """Style conv used in StyleGAN2. + + Args: + in_channels (int): Channel number of the input. + out_channels (int): Channel number of the output. + kernel_size (int): Size of the convolving kernel. + num_style_feat (int): Channel number of style features. + demodulate (bool): Whether demodulate in the conv layer. Default: True. + sample_mode (str | None): Indicating 'upsample', 'downsample' or None. Default: None. 
+ """ + def __init__(self, + in_channels, + out_channels, + kernel_size, + num_style_feat, + demodulate=True, + sample_mode=None): + super(StyleConv, self).__init__() + self.modulated_conv = ModulatedConv2d(in_channels, + out_channels, + kernel_size, + num_style_feat, + demodulate=demodulate, + sample_mode=sample_mode) + + x = paddle.zeros([1], dtype="float32") + self.weight = paddle.create_parameter( + x.shape, + dtype='float32', + default_initializer=paddle.nn.initializer.Assign( + x)) # for noise injection + x = paddle.zeros([1, out_channels, 1, 1], dtype="float32") + self.bias = paddle.create_parameter( + x.shape, + dtype='float32', + default_initializer=paddle.nn.initializer.Assign(x)) + self.activate = nn.LeakyReLU(negative_slope=0.2) + + def forward(self, x, style, noise=None): + out = self.modulated_conv(x, style) * 2**0.5 + if noise is None: + b, _, h, w = out.shape + noise = paddle.normal(shape=[b, 1, h, w]) + out = out + self.weight * noise + out = out + self.bias + out = self.activate(out) + return out + + +class ToRGB(nn.Layer): + """To RGB (image space) from features. + + Args: + in_channels (int): Channel number of input. + num_style_feat (int): Channel number of style features. + upsample (bool): Whether to upsample. Default: True. + """ + def __init__(self, in_channels, num_style_feat, upsample=True): + super(ToRGB, self).__init__() + self.upsample = upsample + self.modulated_conv = ModulatedConv2d(in_channels, + 3, + kernel_size=1, + num_style_feat=num_style_feat, + demodulate=False, + sample_mode=None) + x = paddle.zeros(shape=[1, 3, 1, 1], dtype='float32') + self.bias = paddle.create_parameter( + shape=x.shape, + dtype='float32', + default_initializer=paddle.nn.initializer.Assign(x)) + self.bias.stop_gradient = False + + def forward(self, x, style, skip=None): + """Forward function. + + Args: + x (Tensor): Feature tensor with shape (b, c, h, w). + style (Tensor): Tensor with shape (b, num_style_feat). + skip (Tensor): Base/skip tensor. Default: None. + + Returns: + Tensor: RGB images. + """ + out = self.modulated_conv(x, style) + out = out + self.bias + if skip is not None: + if self.upsample: + skip = F.interpolate(skip, + scale_factor=2, + mode='bilinear', + align_corners=False) + out = out + skip + return out + + +class ConstantInput(nn.Layer): + """Constant input. + + Args: + num_channel (int): Channel number of constant input. + size (int): Spatial size of constant input. + """ + def __init__(self, num_channel, size): + super(ConstantInput, self).__init__() + x = paddle.randn(shape=[1, num_channel, size, size], dtype='float32') + self.weight = paddle.create_parameter( + shape=x.shape, + dtype='float32', + default_initializer=paddle.nn.initializer.Assign(x)) + self.weight.stop_gradient = False + + def forward(self, batch): + out = paddle.tile(self.weight, repeat_times=[batch, 1, 1, 1]) + return out + + +class StyleGAN2GeneratorClean(nn.Layer): + """Clean version of StyleGAN2 Generator. + + Args: + out_size (int): The spatial size of outputs. + num_style_feat (int): Channel number of style features. Default: 512. + num_mlp (int): Layer number of MLP style layers. Default: 8. + channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2. + narrow (float): Narrow ratio for channels. Default: 1.0. 
+ """ + def __init__(self, + out_size, + num_style_feat=512, + num_mlp=8, + channel_multiplier=2, + narrow=1): + super(StyleGAN2GeneratorClean, self).__init__() + self.num_style_feat = num_style_feat + style_mlp_layers = [NormStyleCode()] + for i in range(num_mlp): + style_mlp_layers.extend([ + nn.Linear(num_style_feat, num_style_feat, bias_attr=True), + nn.LeakyReLU(negative_slope=0.2) + ]) + self.style_mlp = nn.Sequential(*style_mlp_layers) + # default_init_weights(self.style_mlp, scale=1, bias_fill=0, a=0.2, + # mode='fan_in', nonlinearity='leaky_relu') + channels = { + '4': int(512 * narrow), + '8': int(512 * narrow), + '16': int(512 * narrow), + '32': int(512 * narrow), + '64': int(256 * channel_multiplier * narrow), + '128': int(128 * channel_multiplier * narrow), + '256': int(64 * channel_multiplier * narrow), + '512': int(32 * channel_multiplier * narrow), + '1024': int(16 * channel_multiplier * narrow) + } + self.channels = channels + self.constant_input = ConstantInput(channels['4'], size=4) + self.style_conv1 = StyleConv(channels['4'], + channels['4'], + kernel_size=3, + num_style_feat=num_style_feat, + demodulate=True, + sample_mode=None) + self.to_rgb1 = ToRGB(channels['4'], num_style_feat, upsample=False) + self.log_size = int(math.log(out_size, 2)) + self.num_layers = (self.log_size - 2) * 2 + 1 + self.num_latent = self.log_size * 2 - 2 + self.style_convs = nn.LayerList() + self.to_rgbs = nn.LayerList() + self.noises = nn.Layer() + in_channels = channels['4'] + for layer_idx in range(self.num_layers): + resolution = 2**((layer_idx + 5) // 2) + shape = [1, 1, resolution, resolution] + self.noises.register_buffer(f'noise{layer_idx}', + paddle.randn(shape=shape)) + for i in range(3, self.log_size + 1): + out_channels = channels[f'{2 ** i}'] + self.style_convs.append(StyleConv(in_channels, out_channels, + kernel_size=3, num_style_feat=num_style_feat, demodulate=\ + True, sample_mode='upsample')) + self.style_convs.append(StyleConv(out_channels, out_channels, + kernel_size=3, num_style_feat=num_style_feat, demodulate=\ + True, sample_mode=None)) + self.to_rgbs.append( + ToRGB(out_channels, num_style_feat, upsample=True)) + in_channels = out_channels + + def make_noise(self): + """Make noise for noise injection.""" + device = self.constant_input.weight.device + noises = [paddle.randn(shape=[1, 1, 4, 4])] + for i in range(3, self.log_size + 1): + for _ in range(2): + noises.append(paddle.randn(shape=[1, 1, 2**i, 2**i])) + return noises + + def get_latent(self, x): + return self.style_mlp(x) + + def mean_latent(self, num_latent): + latent_in = paddle.randn(shape=[num_latent, self.num_style_feat]) + latent = self.style_mlp(latent_in).mean(0, keepdim=True) + return latent + + def forward(self, + styles, + input_is_latent=False, + noise=None, + randomize_noise=True, + truncation=1, + truncation_latent=None, + inject_index=None, + return_latents=False): + """Forward function for StyleGAN2GeneratorClean. + + Args: + styles (list[Tensor]): Sample codes of styles. + input_is_latent (bool): Whether input is latent style. Default: False. + noise (Tensor | None): Input noise or None. Default: None. + randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True. + truncation (float): The truncation ratio. Default: 1. + truncation_latent (Tensor | None): The truncation latent tensor. Default: None. + inject_index (int | None): The injection index for mixing noise. Default: None. + return_latents (bool): Whether to return style latents. Default: False. 
+ """ + if not input_is_latent: + styles = [self.style_mlp(s) for s in styles] + if noise is None: + if randomize_noise: + noise = [None] * self.num_layers + else: + noise = [ + getattr(self.noises, f'noise{i}') + for i in range(self.num_layers) + ] + if truncation < 1: + style_truncation = [] + for style in styles: + style_truncation.append(truncation_latent + truncation * + (style - truncation_latent)) + styles = style_truncation + if len(styles) == 1: + inject_index = self.num_latent + if styles[0].ndim < 3: + latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + else: + latent = styles[0] + elif len(styles) == 2: + if inject_index is None: + inject_index = random.randint(1, self.num_latent - 1) + latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + latent2 = styles[1].unsqueeze(1).repeat( + 1, self.num_latent - inject_index, 1) + latent = paddle.concat([latent1, latent2], axis=1) + out = self.constant_input(latent.shape[0]) + out = self.style_conv1(out, latent[:, 0], noise=noise[0]) + skip = self.to_rgb1(out, latent[:, 1]) + i = 1 + for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], + self.style_convs[1::2], + noise[1::2], + noise[2::2], + self.to_rgbs): + out = conv1(out, latent[:, i], noise=noise1) + out = conv2(out, latent[:, i + 1], noise=noise2) + skip = to_rgb(out, latent[:, i + 2], skip) + i += 2 + image = skip + if return_latents: + return image, latent + else: + return image, None diff --git a/ppgan/models/gfpgan_model.py b/ppgan/models/gfpgan_model.py new file mode 100644 index 0000000..9bde5e4 --- /dev/null +++ b/ppgan/models/gfpgan_model.py @@ -0,0 +1,552 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import sys +import paddle +from paddle.nn import functional as F +from paddle import autograd + +from .base_model import BaseModel +from .builder import MODELS +from .generators.builder import build_generator +from .discriminators.builder import build_discriminator +from .criterions.builder import build_criterion +from ..modules.init import init_weights +from collections import OrderedDict +from ..solver import build_lr_scheduler, build_optimizer +from ppgan.utils.visual import * +from ppgan.models.generators.gfpganv1_arch import FacialComponentDiscriminator +from ppgan.utils.download import get_path_from_url + + +@MODELS.register() +class GFPGANModel(BaseModel): + """ This class implements the gfpgan model. 
+ + """ + def __init__(self, **opt): + + super(GFPGANModel, self).__init__() + self.opt = opt + train_opt = opt + if 'image_visual' in self.opt['path']: + self.image_paths = self.opt['path']['image_visual'] + self.current_iter = 0 + self.nets['net_g'] = build_generator(opt['network_g']) + self.log_size = int(math.log(self.opt['network_g']['out_size'], 2)) + # define networks (both generator and discriminator) + self.nets['net_g_ema'] = build_generator(self.opt['network_g']) + self.nets['net_d'] = build_discriminator(self.opt['network_d']) + self.nets['net_g_ema'].eval() + pretrain_network_g = self.opt['path'].get('pretrain_network_g', None) + if pretrain_network_g != None: + t_weight = get_path_from_url(pretrain_network_g) + t_weight = paddle.load(t_weight) + if 'net_g' in t_weight: + self.nets['net_g'].set_state_dict(t_weight['net_g']) + self.nets['net_g_ema'].set_state_dict(t_weight['net_g_ema']) + else: + self.nets['net_g'].set_state_dict(t_weight) + self.nets['net_g_ema'].set_state_dict(t_weight) + + del t_weight + + self.nets['net_d'].train() + self.nets['net_g'].train() + if ('network_d_left_eye' in self.opt + and 'network_d_right_eye' in self.opt + and 'network_d_mouth' in self.opt): + self.use_facial_disc = True + else: + self.use_facial_disc = False + + if self.use_facial_disc: + # left eye + self.nets['net_d_left_eye'] = FacialComponentDiscriminator() + self.nets['net_d_right_eye'] = FacialComponentDiscriminator() + self.nets['net_d_mouth'] = FacialComponentDiscriminator() + load_path = self.opt['path'].get('pretrain_network_d_left_eye') + if load_path is not None: + load_val = get_path_from_url(load_path) + load_val = paddle.load(load_val) + self.nets['net_d_left_eye'].set_state_dict(load_val) + self.nets['net_d_right_eye'].set_state_dict(load_val) + self.nets['net_d_mouth'].set_state_dict(load_val) + del load_val + self.nets['net_d_left_eye'].train() + self.nets['net_d_right_eye'].train() + self.nets['net_d_mouth'].train() + self.cri_component = build_criterion(train_opt['gan_component_opt']) + + if train_opt.get('pixel_opt'): + self.cri_pix = build_criterion(train_opt['pixel_opt']) + else: + self.cri_pix = None + + # perceptual loss + if train_opt.get('perceptual_opt'): + self.cri_perceptual = build_criterion(train_opt['perceptual_opt']) + else: + self.cri_perceptual = None + + # L1 loss is used in pyramid loss, component style loss and identity loss + self.cri_l1 = build_criterion(train_opt['L1_opt']) + + # gan loss (wgan) + self.cri_gan = build_criterion(train_opt['gan_opt']) + + # ----------- define identity loss ----------- # + if 'network_identity' in self.opt: + self.use_identity = True + else: + self.use_identity = False + + if self.use_identity: + # define identity network + self.network_identity = build_discriminator( + self.opt['network_identity']) + load_path = self.opt['path'].get('pretrain_network_identity') + if load_path is not None: + load_val = get_path_from_url(load_path) + load_val = paddle.load(load_val) + self.network_identity.set_state_dict(load_val) + del load_val + self.network_identity.eval() + for param in self.network_identity.parameters(): + param.stop_gradient = True + + # regularization weights + self.r1_reg_weight = train_opt['r1_reg_weight'] # for discriminator + self.net_d_iters = train_opt.get('net_d_iters', 1) + self.net_d_init_iters = train_opt.get('net_d_init_iters', 0) + self.net_d_reg_every = train_opt['net_d_reg_every'] + + def setup_input(self, data): + self.lq = data['lq'] + + if 'gt' in data: + self.gt = data['gt'] + + if 
'loc_left_eye' in data: + # get facial component locations, shape (batch, 4) + self.loc_left_eyes = data['loc_left_eye'].astype('float32') + self.loc_right_eyes = data['loc_right_eye'].astype('float32') + self.loc_mouths = data['loc_mouth'].astype('float32') + + def forward(self, test_mode=False, regularize=False): + pass + + def train_iter(self, optimizers=None): + # optimize nets['net_g'] + for p in self.nets['net_d'].parameters(): + p.stop_gradient = True + self.optimizers['optim_g'].clear_grad(set_to_zero=False) + + # do not update facial component net_d + if self.use_facial_disc: + for p in self.nets['net_d_left_eye'].parameters(): + p.stop_gradient = True + for p in self.nets['net_d_right_eye'].parameters(): + p.stop_gradient = True + for p in self.nets['net_d_mouth'].parameters(): + p.stop_gradient = True + + # image pyramid loss weight + pyramid_loss_weight = self.opt.get('pyramid_loss_weight', 0) + if pyramid_loss_weight > 0 and self.current_iter > self.opt.get( + 'remove_pyramid_loss', float('inf')): + pyramid_loss_weight = 1e-12 # very small weight to avoid unused param error + if pyramid_loss_weight > 0: + self.output, out_rgbs = self.nets['net_g'](self.lq, return_rgb=True) + pyramid_gt = self.construct_img_pyramid() + else: + self.output, out_rgbs = self.nets['net_g'](self.lq, + return_rgb=False) + + # get roi-align regions + if self.use_facial_disc: + self.get_roi_regions(eye_out_size=80, mouth_out_size=120) + l_g_total = 0 + if (self.current_iter % self.net_d_iters == 0 + and self.current_iter > self.net_d_init_iters): + # pixel loss + if self.cri_pix: + l_g_pix = self.cri_pix(self.output, self.gt) + l_g_total += l_g_pix + self.losses['l_g_pix'] = l_g_pix + + # image pyramid loss + if pyramid_loss_weight > 0: + for i in range(0, self.log_size - 2): + l_pyramid = self.cri_l1(out_rgbs[i], + pyramid_gt[i]) * pyramid_loss_weight + l_g_total += l_pyramid + self.losses[f'l_p_{2**(i+3)}'] = l_pyramid + + # perceptual loss + if self.cri_perceptual: + l_g_percep, l_g_style = self.cri_perceptual( + self.output, self.gt) + if l_g_percep is not None: + l_g_total += l_g_percep + self.losses['l_g_percep'] = l_g_percep + if l_g_style is not None: + l_g_total += l_g_style + self.losses['l_g_style'] = l_g_style + + # gan loss + fake_g_pred = self.nets['net_d'](self.output) + l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False) + l_g_total += l_g_gan + self.losses['l_g_gan'] = l_g_gan + + # facial component loss + if self.use_facial_disc: + # left eye + fake_left_eye, fake_left_eye_feats = self.nets[ + 'net_d_left_eye'](self.left_eyes, return_feats=True) + l_g_gan = self.cri_component(fake_left_eye, True, is_disc=False) + l_g_total += l_g_gan + self.losses['l_g_gan_left_eye'] = l_g_gan + # right eye + fake_right_eye, fake_right_eye_feats = self.nets[ + 'net_d_right_eye'](self.right_eyes, return_feats=True) + l_g_gan = self.cri_component(fake_right_eye, + True, + is_disc=False) + l_g_total += l_g_gan + self.losses['l_g_gan_right_eye'] = l_g_gan + # mouth + fake_mouth, fake_mouth_feats = self.nets['net_d_mouth']( + self.mouths, return_feats=True) + l_g_gan = self.cri_component(fake_mouth, True, is_disc=False) + l_g_total += l_g_gan + self.losses['l_g_gan_mouth'] = l_g_gan + + if self.opt.get('comp_style_weight', 0) > 0: + # get gt feat + _, real_left_eye_feats = self.nets['net_d_left_eye']( + self.left_eyes_gt, return_feats=True) + _, real_right_eye_feats = self.nets['net_d_right_eye']( + self.right_eyes_gt, return_feats=True) + _, real_mouth_feats = self.nets['net_d_mouth']( + 
self.mouths_gt, return_feats=True) + + def _comp_style(feat, feat_gt, criterion): + return criterion(self._gram_mat( + feat[0]), self._gram_mat( + feat_gt[0].detach())) * 0.5 + criterion( + self._gram_mat(feat[1]), + self._gram_mat(feat_gt[1].detach())) + + # facial component style loss + comp_style_loss = 0 + comp_style_loss += _comp_style(fake_left_eye_feats, + real_left_eye_feats, + self.cri_l1) + comp_style_loss += _comp_style(fake_right_eye_feats, + real_right_eye_feats, + self.cri_l1) + comp_style_loss += _comp_style(fake_mouth_feats, + real_mouth_feats, + self.cri_l1) + comp_style_loss = comp_style_loss * self.opt[ + 'comp_style_weight'] + l_g_total += comp_style_loss + self.losses['l_g_comp_style_loss'] = comp_style_loss + + # identity loss + if self.use_identity: + identity_weight = self.opt['identity_weight'] + # get gray images and resize + out_gray = self.gray_resize_for_identity(self.output) + gt_gray = self.gray_resize_for_identity(self.gt) + + identity_gt = self.network_identity(gt_gray).detach() + identity_out = self.network_identity(out_gray) + l_identity = self.cri_l1(identity_out, + identity_gt) * identity_weight + l_g_total += l_identity + self.losses['l_identity'] = l_identity + + l_g_total.backward() + self.optimizers['optim_g'].step() + # EMA + self.accumulate(self.nets['net_g_ema'], + self.nets['net_g'], + decay=0.5**(32 / (10 * 1000))) + + # ----------- optimize net_d ----------- # + for p in self.nets['net_d'].parameters(): + p.stop_gradient = False + self.optimizers['optim_d'].clear_grad(set_to_zero=False) + if self.use_facial_disc: + for p in self.nets['net_d_left_eye'].parameters(): + p.stop_gradient = False + for p in self.nets['net_d_right_eye'].parameters(): + p.stop_gradient = False + for p in self.nets['net_d_mouth'].parameters(): + p.stop_gradient = False + self.optimizers['optim_net_d_left_eye'].clear_grad( + set_to_zero=False) + self.optimizers['optim_net_d_right_eye'].clear_grad( + set_to_zero=False) + self.optimizers['optim_net_d_mouth'].clear_grad(set_to_zero=False) + fake_d_pred = self.nets['net_d'](self.output.detach()) + real_d_pred = self.nets['net_d'](self.gt) + + l_d = self.cri_gan(real_d_pred, True, is_disc=True) + self.cri_gan( + fake_d_pred, False, is_disc=True) + self.losses['l_d'] = l_d + # In WGAN, real_score should be positive and fake_score should be negative + self.losses['real_score'] = real_d_pred.detach().mean() + self.losses['fake_score'] = fake_d_pred.detach().mean() + l_d.backward() + if self.current_iter % self.net_d_reg_every == 0: + self.gt.stop_gradient = False + real_pred = self.nets['net_d'](self.gt) + l_d_r1 = r1_penalty(real_pred, self.gt) + l_d_r1 = (self.r1_reg_weight / 2 * l_d_r1 * self.net_d_reg_every + + 0 * real_pred[0]) + self.losses['l_d_r1'] = l_d_r1.detach().mean() + l_d_r1.backward() + + self.optimizers['optim_d'].step() + + # optimize facial component discriminators + if self.use_facial_disc: + # left eye + fake_d_pred, _ = self.nets['net_d_left_eye']( + self.left_eyes.detach()) + real_d_pred, _ = self.nets['net_d_left_eye'](self.left_eyes_gt) + l_d_left_eye = self.cri_component( + real_d_pred, True, is_disc=True) + self.cri_gan( + fake_d_pred, False, is_disc=True) + self.losses['l_d_left_eye'] = l_d_left_eye + l_d_left_eye.backward() + # right eye + fake_d_pred, _ = self.nets['net_d_right_eye']( + self.right_eyes.detach()) + real_d_pred, _ = self.nets['net_d_right_eye'](self.right_eyes_gt) + l_d_right_eye = self.cri_component( + real_d_pred, True, is_disc=True) + self.cri_gan( + fake_d_pred, False, 
is_disc=True) + self.losses['l_d_right_eye'] = l_d_right_eye + l_d_right_eye.backward() + # mouth + fake_d_pred, _ = self.nets['net_d_mouth'](self.mouths.detach()) + real_d_pred, _ = self.nets['net_d_mouth'](self.mouths_gt) + l_d_mouth = self.cri_component(real_d_pred, True, + is_disc=True) + self.cri_gan( + fake_d_pred, False, is_disc=True) + self.losses['l_d_mouth'] = l_d_mouth + l_d_mouth.backward() + + self.optimizers['optim_net_d_left_eye'].step() + self.optimizers['optim_net_d_right_eye'].step() + self.optimizers['optim_net_d_mouth'].step() + # if self.current_iter%1000==0: + + def test_iter(self, metrics=None): + self.nets['net_g_ema'].eval() + self.fake_img, _ = self.nets['net_g_ema'](self.lq) + self.visual_items['cur_fake'] = self.fake_img[0] + self.visual_items['cur_gt'] = self.gt[0] + self.visual_items['cur_lq'] = self.lq[0] + with paddle.no_grad(): + if metrics is not None: + for metric in metrics.values(): + metric.update(self.fake_img.detach().numpy(), + self.gt.detach().numpy()) + + def setup_lr_schedulers(self, cfg): + self.lr_scheduler = OrderedDict() + self.lr_scheduler['_g'] = build_lr_scheduler(cfg) + self.lr_scheduler['_component'] = build_lr_scheduler(cfg) + cfg_d = cfg.copy() + net_d_reg_ratio = self.net_d_reg_every / (self.net_d_reg_every + 1) + cfg_d['learning_rate'] *= net_d_reg_ratio + self.lr_scheduler['_d'] = build_lr_scheduler(cfg_d) + return self.lr_scheduler + + def setup_optimizers(self, lr, cfg): + # ----------- optimizer g ----------- # + net_g_reg_ratio = 1 + parameters = [] + parameters += self.nets['net_g'].parameters() + cfg['optim_g']['beta1'] = 0**net_g_reg_ratio + cfg['optim_g']['beta2'] = 0.99**net_g_reg_ratio + + self.optimizers['optim_g'] = build_optimizer(cfg['optim_g'], + self.lr_scheduler['_g'], + parameters) + + # ----------- optimizer d ----------- # + net_d_reg_ratio = self.net_d_reg_every / (self.net_d_reg_every + 1) + parameters = [] + parameters += self.nets['net_d'].parameters() + cfg['optim_d']['beta1'] = 0**net_d_reg_ratio + cfg['optim_d']['beta2'] = 0.99**net_d_reg_ratio + + self.optimizers['optim_d'] = build_optimizer(cfg['optim_d'], + self.lr_scheduler['_d'], + parameters) + + # ----------- optimizers for facial component networks ----------- # + if self.use_facial_disc: + parameters = [] + parameters += self.nets['net_d_left_eye'].parameters() + + self.optimizers['optim_net_d_left_eye'] = build_optimizer( + cfg['optim_component'], self.lr_scheduler['_component'], + parameters) + + parameters = [] + parameters += self.nets['net_d_right_eye'].parameters() + + self.optimizers['optim_net_d_right_eye'] = build_optimizer( + cfg['optim_component'], self.lr_scheduler['_component'], + parameters) + + parameters = [] + parameters += self.nets['net_d_mouth'].parameters() + + self.optimizers['optim_net_d_mouth'] = build_optimizer( + cfg['optim_component'], self.lr_scheduler['_component'], + parameters) + + return self.optimizers + + def construct_img_pyramid(self): + """Construct image pyramid for intermediate restoration loss""" + pyramid_gt = [self.gt] + down_img = self.gt + for _ in range(0, self.log_size - 3): + down_img = F.interpolate(down_img, + scale_factor=0.5, + mode='bilinear', + align_corners=False) + pyramid_gt.insert(0, down_img) + return pyramid_gt + + def get_roi_regions(self, eye_out_size=80, mouth_out_size=120): + from paddle.vision.ops import roi_align + face_ratio = int(self.opt['network_g']['out_size'] / 512) + eye_out_size *= face_ratio + mouth_out_size *= face_ratio + + rois_eyes = [] + rois_mouths = [] + num_eye = [] 
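Two small pieces of arithmetic above are easy to miss. First, because R1 regularization runs only every `net_d_reg_every` iterations (lazy regularization), `setup_lr_schedulers` and `setup_optimizers` rescale the discriminator's learning rate and Adam betas by `k / (k + 1)`. Second, `train_iter` updates `net_g_ema` with a fixed exponential-moving-average decay. A sketch of both, using the values from `configs/gfpgan_ffhq1024.yaml`:

```python
# lazy regularization: the discriminator optimizer is rescaled by k / (k + 1)
k = 16                               # net_d_reg_every
ratio = k / (k + 1)                  # ~0.941
lr_d = 0.002 * ratio                 # discriminator learning rate
beta1_d, beta2_d = 0 ** ratio, 0.99 ** ratio

# EMA decay used by accumulate() after every generator step
decay = 0.5 ** (32 / (10 * 1000))    # ~0.9978; net_g_ema tracks net_g slowly
```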
+ num_mouth = [] + for b in range(self.loc_left_eyes.shape[0]): # loop for batch size + # left eye and right eye + + img_inds = paddle.ones([2, 1], dtype=self.loc_left_eyes.dtype) * b + bbox = paddle.stack( + [self.loc_left_eyes[b, :], self.loc_right_eyes[b, :]], + axis=0) # shape: (2, 4) + # rois = paddle.concat([img_inds, bbox], axis=-1) # shape: (2, 5) + rois_eyes.append(bbox) + # mouse + img_inds = paddle.ones([1, 1], dtype=self.loc_left_eyes.dtype) * b + num_eye.append(2) + num_mouth.append(1) + # rois = paddle.concat([img_inds, self.loc_mouths[b:b + 1, :]], axis=-1) # shape: (1, 5) + rois_mouths.append(self.loc_mouths[b:b + 1, :]) + rois_eyes = paddle.concat(rois_eyes, 0) + rois_mouths = paddle.concat(rois_mouths, 0) + # real images + num_eye = paddle.to_tensor(num_eye, dtype='int32') + num_mouth = paddle.to_tensor(num_mouth, dtype='int32') + + all_eyes = roi_align(self.gt, + boxes=rois_eyes, + boxes_num=num_eye, + output_size=eye_out_size, + aligned=False) * face_ratio + self.left_eyes_gt = all_eyes[0::2, :, :, :] + self.right_eyes_gt = all_eyes[1::2, :, :, :] + self.mouths_gt = roi_align(self.gt, + boxes=rois_mouths, + boxes_num=num_mouth, + output_size=mouth_out_size, + aligned=False) * face_ratio + # output + all_eyes = roi_align(self.output, + boxes=rois_eyes, + boxes_num=num_eye, + output_size=eye_out_size, + aligned=False) * face_ratio + self.left_eyes = all_eyes[0::2, :, :, :] + self.right_eyes = all_eyes[1::2, :, :, :] + self.mouths = roi_align(self.output, + boxes=rois_mouths, + boxes_num=num_mouth, + output_size=mouth_out_size, + aligned=False) * face_ratio + + def _gram_mat(self, x): + """Calculate Gram matrix. + + Args: + x (paddle.Tensor): Tensor with shape of (n, c, h, w). + + Returns: + paddle.Tensor: Gram matrix. + """ + n, c, h, w = x.shape + features = x.reshape((n, c, w * h)) + features_t = features.transpose([0, 2, 1]) + gram = features.bmm(features_t) / (c * h * w) + return gram + + def gray_resize_for_identity(self, out, size=128): + out_gray = (0.2989 * out[:, 0, :, :] + 0.5870 * out[:, 1, :, :] + + 0.1140 * out[:, 2, :, :]) + out_gray = out_gray.unsqueeze(1) + out_gray = F.interpolate(out_gray, (size, size), + mode='bilinear', + align_corners=False) + return out_gray + + def accumulate(self, model1, model2, decay=0.999): + par1 = dict(model1.state_dict()) + par2 = dict(model2.state_dict()) + + for k in par1.keys(): + par1[k] = par1[k] * decay + par2[k] * (1 - decay) + + model1.load_dict(par1) + + +def r1_penalty(real_pred, real_img): + """R1 regularization for discriminator. The core idea is to + penalize the gradient on real data alone: when the + generator distribution produces the true data distribution + and the discriminator is equal to 0 on the data manifold, the + gradient penalty ensures that the discriminator cannot create + a non-zero gradient orthogonal to the data manifold without + suffering a loss in the GAN game. + + Ref: + Eq. 9 in Which training methods for GANs do actually converge. + """ + grad_real = paddle.grad(outputs=real_pred.sum(), + inputs=real_img, + create_graph=True)[0] + grad_penalty = grad_real.pow(2).reshape( + (grad_real.shape[0], -1)).sum(1).mean() + return grad_penalty diff --git a/ppgan/utils/gfpgan_tools.py b/ppgan/utils/gfpgan_tools.py new file mode 100644 index 0000000..80d59b7 --- /dev/null +++ b/ppgan/utils/gfpgan_tools.py @@ -0,0 +1,1127 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import cv2 +import math +import numpy as np +import random +import os +import os.path as osp +from abc import ABCMeta +from abc import abstractmethod +import paddle +import paddle.nn.functional as F +from paddle.vision.transforms.functional import normalize + + +def _blend(img1, img2, ratio): + ratio = float(ratio) + bound = 1.0 if paddle.is_floating_point(img1) else 255.0 + return (ratio * img1 + (1.0 - ratio) * img2).clip(0, bound) + + +def _get_image_num_channels(img): + if img.ndim == 2: + return 1 + elif img.ndim > 2: + return img.shape[-3] + + raise TypeError("Input ndim should be 2 or more. Got {}".format(img.ndim)) + + +def _rgb2hsv(img): + r, g, b = img.unbind(axis=-3) + + # Implementation is based on https://github.com/python-pillow/Pillow/blob/4174d4267616897df3746d315d5a2d0f82c656ee/ + # src/libImaging/Convert.c#L330 + maxc = paddle.max(img, axis=-3) + minc = paddle.min(img, axis=-3) + + # The algorithm erases S and H channel where `maxc = minc`. This avoids NaN + # from happening in the results, because + # + S channel has division by `maxc`, which is zero only if `maxc = minc` + # + H channel has division by `(maxc - minc)`. + # + # Instead of overwriting NaN afterwards, we just prevent it from occuring so + # we don't need to deal with it in case we save the NaN in a buffer in + # backprop, if it is ever supported, but it doesn't hurt to do so. + eqc = maxc == minc + + cr = maxc - minc + # Since `eqc => cr = 0`, replacing denominator with 1 when `eqc` is fine. + ones = paddle.ones_like(maxc) + s = cr / paddle.where(eqc, ones, maxc) + # Note that `eqc => maxc = minc = r = g = b`. So the following calculation + # of `h` would reduce to `bc - gc + 2 + rc - bc + 4 + rc - bc = 6` so it + # would not matter what values `rc`, `gc`, and `bc` have here, and thus + # replacing denominator with 1 when `eqc` is fine. 
+ cr_divisor = paddle.where(eqc, ones, cr) + rc = (maxc - r) / cr_divisor + gc = (maxc - g) / cr_divisor + bc = (maxc - b) / cr_divisor + + t_zero = paddle.zeros_like(bc) + hr = paddle.where(maxc == r, (bc - gc), t_zero) + hg = paddle.where((maxc == g) & (maxc != r), (2.0 + rc - bc), t_zero) + hb = paddle.where((maxc != g) & (maxc != r), (4.0 + gc - rc), t_zero) + + h = (hr + hg + hb) + h = paddle.mod((h / 6.0 + 1.0), paddle.to_tensor([1.0])) + return paddle.stack((h, s, maxc), axis=-3) + + +def _hsv2rgb(img): + h, s, v = img.unbind(axis=-3) + i = paddle.floor(h * 6.0) + f = (h * 6.0) - i + i = paddle.cast(i, dtype='int32') + + p = paddle.clip((v * (1.0 - s)), 0.0, 1.0) + q = paddle.clip((v * (1.0 - s * f)), 0.0, 1.0) + t = paddle.clip((v * (1.0 - s * (1.0 - f))), 0.0, 1.0) + i = i % 6 + + mask = i.unsqueeze(axis=-3) == paddle.arange(6).reshape([-1, 1, 1]) + + a1 = paddle.stack((v, q, p, p, t, v), axis=-3) + a2 = paddle.stack((t, v, v, q, p, p), axis=-3) + a3 = paddle.stack((p, p, t, v, v, q), axis=-3) + a4 = paddle.stack((a1, a2, a3), axis=-4) + t_zero = paddle.zeros_like(mask, dtype='float32') + t_ones = paddle.ones_like(mask, dtype='float32') + mask = paddle.where(mask, t_ones, t_zero) + return paddle.einsum("...ijk, ...xijk -> ...xjk", mask, a4) + + +def rgb_to_grayscale(img, num_output_channels=1): + if img.ndim < 3: + raise TypeError( + "Input image tensor should have at least 3 axisensions, but found {}" + .format(img.ndim)) + + if num_output_channels not in (1, 3): + raise ValueError('num_output_channels should be either 1 or 3') + + r, g, b = img.unbind(axis=-3) + l_img = (0.2989 * r + 0.587 * g + 0.114 * b) + l_img = l_img.unsqueeze(axis=-3) + + if num_output_channels == 3: + return l_img.expand(img.shape) + + return l_img + + +def adjust_brightness(img, brightness_factor): + if brightness_factor < 0: + raise ValueError('brightness_factor ({}) is not non-negative.'.format( + brightness_factor)) + + return _blend(img, paddle.zeros_like(img), brightness_factor) + + +def adjust_contrast(img, contrast_factor): + if contrast_factor < 0: + raise ValueError( + 'contrast_factor ({}) is not non-negative.'.format(contrast_factor)) + + dtype = img.dtype if paddle.is_floating_point(img) else paddle.float32 + mean = paddle.mean(paddle.cast(rgb_to_grayscale(img), dtype=dtype), + axis=(-3, -2, -1), + keepdim=True) + + return _blend(img, mean, contrast_factor) + + +def adjust_hue(img, hue_factor): + if not (-0.5 <= hue_factor <= 0.5): + raise ValueError( + 'hue_factor ({}) is not in [-0.5, 0.5].'.format(hue_factor)) + + if not (isinstance(img, paddle.Tensor)): + raise TypeError('Input img should be Tensor image') + + if _get_image_num_channels(img) == 1: # Match PIL behaviour + return img + + orig_dtype = img.dtype + if img.dtype == paddle.uint8: + img = paddle.cast(img, dtype='float32') / 255.0 + + img = _rgb2hsv(img) + h, s, v = img.unbind(axis=-3) + h = (h + hue_factor) % 1.0 + img = paddle.stack((h, s, v), axis=-3) + img_hue_adj = _hsv2rgb(img) + + if orig_dtype == paddle.uint8: + img_hue_adj = paddle.cast(img_hue_adj * 255.0, dtype=orig_dtype) + + return img_hue_adj + + +def adjust_saturation(img, saturation_factor): + if saturation_factor < 0: + raise ValueError('saturation_factor ({}) is not non-negative.'.format( + saturation_factor)) + + return _blend(img, rgb_to_grayscale(img), saturation_factor) + + +def generate_gaussian_noise(img, sigma=10, gray_noise=False): + """Generate Gaussian noise. + + Args: + img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32. 
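The `adjust_*` helpers above are a Paddle port of torchvision-style colour-jitter primitives, apparently intended for the tensor-level jitter controlled by `color_jitter_pt_prob` in the config. A hedged usage sketch on a CHW tensor in `[0, 1]`:

```python
import paddle

img = paddle.rand([3, 256, 256])            # CHW, float32, range [0, 1]
img = adjust_brightness(img, 1.2)
img = adjust_contrast(img, 0.8)
img = adjust_saturation(img, 1.1)
img = adjust_hue(img, 0.05)                 # hue_factor must lie in [-0.5, 0.5]
```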
+ sigma (float): Noise scale (measured in range 255). Default: 10. + + Returns: + (Numpy array): Returned noisy image, shape (h, w, c), range[0, 1], + float32. + """ + if gray_noise: + noise = np.float32(np.random.randn(*img.shape[0:2])) * sigma / 255.0 + noise = np.expand_dims(noise, axis=2).repeat(3, axis=2) + else: + noise = np.float32(np.random.randn(*img.shape)) * sigma / 255.0 + return noise + + +def random_generate_gaussian_noise(img, sigma_range=(0, 10), gray_prob=0): + sigma = np.random.uniform(sigma_range[0], sigma_range[1]) + if np.random.uniform() < gray_prob: + gray_noise = True + else: + gray_noise = False + return generate_gaussian_noise(img, sigma, gray_noise) + + +def random_add_gaussian_noise(img, sigma_range=(0, 1.0), gray_prob=0, clip=\ + True, rounds=False): + noise = random_generate_gaussian_noise(img, sigma_range, gray_prob) + out = img + noise + if clip and rounds: + out = np.clip((out * 255.0).round(), 0, 255) / 255.0 + elif clip: + out = np.clip(out, 0, 1) + elif rounds: + out = (out * 255.0).round() / 255.0 + return out + + +def add_jpg_compression(img, quality=90): + """Add JPG compression artifacts. + + Args: + img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32. + quality (float): JPG compression quality. 0 for lowest quality, 100 for + best quality. Default: 90. + + Returns: + (Numpy array): Returned image after JPG, shape (h, w, c), range[0, 1], + float32. + """ + img = np.clip(img, 0, 1) + encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), quality] + _, encimg = cv2.imencode('.jpg', img * 255.0, encode_param) + img = np.float32(cv2.imdecode(encimg, 1)) / 255.0 + return img + + +def random_add_jpg_compression(img, quality_range=(90, 100)): + """Randomly add JPG compression artifacts. + + Args: + img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32. + quality_range (tuple[float] | list[float]): JPG compression quality + range. 0 for lowest quality, 100 for best quality. + Default: (90, 100). + + Returns: + (Numpy array): Returned image after JPG, shape (h, w, c), range[0, 1], + float32. + """ + quality = int(np.random.uniform(quality_range[0], quality_range[1])) + return add_jpg_compression(img, quality) + + +def random_mixed_kernels(kernel_list, + kernel_prob, + kernel_size=21, + sigma_x_range=(0.6, 5), + sigma_y_range=(0.6, 5), + rotation_range=(-math.pi, math.pi), + betag_range=(0.5, 8), + betap_range=(0.5, 8), + noise_range=None): + """Randomly generate mixed kernels. + + Args: + kernel_list (tuple): a list name of kernel types, + support ['iso', 'aniso', 'skew', 'generalized', 'plateau_iso', + 'plateau_aniso'] + kernel_prob (tuple): corresponding kernel probability for each + kernel type + kernel_size (int): + sigma_x_range (tuple): [0.6, 5] + sigma_y_range (tuple): [0.6, 5] + rotation range (tuple): [-math.pi, math.pi] + beta_range (tuple): [0.5, 8] + noise_range(tuple, optional): multiplicative kernel noise, + [0.75, 1.25]. 
Default: None + + Returns: + kernel (ndarray): + """ + kernel_type = random.choices(kernel_list, kernel_prob)[0] + if kernel_type == 'iso': + kernel = random_bivariate_Gaussian(kernel_size, + sigma_x_range, + sigma_y_range, + rotation_range, + noise_range=noise_range, + isotropic=True) + elif kernel_type == 'aniso': + kernel = random_bivariate_Gaussian(kernel_size, + sigma_x_range, + sigma_y_range, + rotation_range, + noise_range=noise_range, + isotropic=False) + return kernel + + +def random_bivariate_Gaussian(kernel_size, + sigma_x_range, + sigma_y_range, + rotation_range, + noise_range=None, + isotropic=True): + """Randomly generate bivariate isotropic or anisotropic Gaussian kernels. + + In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and `rotation_range` is ignored. + + Args: + kernel_size (int): + sigma_x_range (tuple): [0.6, 5] + sigma_y_range (tuple): [0.6, 5] + rotation range (tuple): [-math.pi, math.pi] + noise_range(tuple, optional): multiplicative kernel noise, + [0.75, 1.25]. Default: None + + Returns: + kernel (ndarray): + """ + assert kernel_size % 2 == 1, 'Kernel size must be an odd number.' + assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.' + sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1]) + if isotropic is False: + assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.' + assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.' + sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1]) + rotation = np.random.uniform(rotation_range[0], rotation_range[1]) + else: + sigma_y = sigma_x + rotation = 0 + kernel = bivariate_Gaussian(kernel_size, + sigma_x, + sigma_y, + rotation, + isotropic=isotropic) + if noise_range is not None: + assert noise_range[0] < noise_range[1], 'Wrong noise range.' + noise = np.random.uniform(noise_range[0], noise_range[1], size=\ + kernel.shape) + kernel = kernel * noise + kernel = kernel / np.sum(kernel) + return kernel + + +def bivariate_Gaussian(kernel_size, + sig_x, + sig_y, + theta, + grid=None, + isotropic=True): + """Generate a bivariate isotropic or anisotropic Gaussian kernel. + + In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` is ignored. + + Args: + kernel_size (int): + sig_x (float): + sig_y (float): + theta (float): Radian measurement. + grid (ndarray, optional): generated by :func:`mesh_grid`, + with the shape (K, K, 2), K is the kernel size. Default: None + isotropic (bool): + + Returns: + kernel (ndarray): normalized kernel. + """ + if grid is None: + grid, _, _ = mesh_grid(kernel_size) + if isotropic: + sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]]) + else: + sigma_matrix = sigma_matrix2(sig_x, sig_y, theta) + kernel = pdf2(sigma_matrix, grid) + kernel = kernel / np.sum(kernel) + return kernel + + +def sigma_matrix2(sig_x, sig_y, theta): + """Calculate the rotated sigma matrix (two dimensional matrix). + + Args: + sig_x (float): + sig_y (float): + theta (float): Radian measurement. + + Returns: + ndarray: Rotated sigma matrix. + """ + d_matrix = np.array([[sig_x**2, 0], [0, sig_y**2]]) + u_matrix = np.array([[np.cos(theta), -np.sin(theta)], + [np.sin(theta), np.cos(theta)]]) + return np.dot(u_matrix, np.dot(d_matrix, u_matrix.T)) + + +def mesh_grid(kernel_size): + """Generate the mesh grid, centering at zero. 
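A hedged sketch of how these kernel generators are typically driven with the degradation settings from the config (`blur_kernel_size: 41`, `blur_sigma: [0.1, 10]`, `kernel_list: ['iso', 'aniso']`); the GT image here is just a random stand-in:

```python
import math

import cv2
import numpy as np

kernel = random_mixed_kernels(
    kernel_list=['iso', 'aniso'], kernel_prob=[0.5, 0.5],
    kernel_size=41,
    sigma_x_range=[0.1, 10], sigma_y_range=[0.1, 10],
    rotation_range=[-math.pi, math.pi], noise_range=None)

img_gt = np.random.rand(512, 512, 3).astype(np.float32)   # stand-in for a GT face in [0, 1]
img_blur = cv2.filter2D(img_gt, -1, kernel)                # blurring is the first degradation step
```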
+ + Args: + kernel_size (int): + + Returns: + xy (ndarray): with the shape (kernel_size, kernel_size, 2) + xx (ndarray): with the shape (kernel_size, kernel_size) + yy (ndarray): with the shape (kernel_size, kernel_size) + """ + ax = np.arange(-kernel_size // 2 + 1.0, kernel_size // 2 + 1.0) + xx, yy = np.meshgrid(ax, ax) + xy = np.hstack((xx.reshape((kernel_size * kernel_size, 1)), + yy.reshape(kernel_size * kernel_size, + 1))).reshape(kernel_size, kernel_size, 2) + return xy, xx, yy + + +def pdf2(sigma_matrix, grid): + """Calculate PDF of the bivariate Gaussian distribution. + + Args: + sigma_matrix (ndarray): with the shape (2, 2) + grid (ndarray): generated by :func:`mesh_grid`, + with the shape (K, K, 2), K is the kernel size. + + Returns: + kernel (ndarrray): un-normalized kernel. + """ + inverse_sigma = np.linalg.inv(sigma_matrix) + kernel = np.exp(-0.5 * np.sum(np.dot(grid, inverse_sigma) * grid, 2)) + return kernel + + +def paths_from_folder(folder): + """Generate paths from folder. + + Args: + folder (str): Folder path. + + Returns: + list[str]: Returned path list. + """ + paths = list(scandir(folder)) + paths = [osp.join(folder, path) for path in paths] + return paths + + +def scandir(dir_path, suffix=None, recursive=False, full_path=False): + """Scan a directory to find the interested files. + + Args: + dir_path (str): Path of the directory. + suffix (str | tuple(str), optional): File suffix that we are + interested in. Default: None. + recursive (bool, optional): If set to True, recursively scan the + directory. Default: False. + full_path (bool, optional): If set to True, include the dir_path. + Default: False. + + Returns: + A generator for all the interested files with relative paths. + """ + if suffix is not None and not isinstance(suffix, (str, tuple)): + raise TypeError('"suffix" must be a string or tuple of strings') + root = dir_path + + def _scandir(dir_path, suffix, recursive): + for entry in os.scandir(dir_path): + if not entry.name.startswith('.') and entry.is_file(): + if full_path: + return_path = entry.path + else: + return_path = osp.relpath(entry.path, root) + if suffix is None: + yield return_path + elif return_path.endswith(suffix): + yield return_path + elif recursive: + yield from _scandir(entry.path, suffix=suffix, recursive=\ + recursive) + else: + continue + + return _scandir(dir_path, suffix=suffix, recursive=recursive) + + +class BaseStorageBackend(metaclass=ABCMeta): + """Abstract class of storage backends. + + All backends need to implement two apis: ``get()`` and ``get_text()``. + ``get()`` reads the file as a byte stream and ``get_text()`` reads the file + as texts. + """ + @abstractmethod + def get(self, filepath): + pass + + @abstractmethod + def get_text(self, filepath): + pass + + +class MemcachedBackend(BaseStorageBackend): + """Memcached storage backend. + + Attributes: + server_list_cfg (str): Config file for memcached server list. + client_cfg (str): Config file for memcached client. + sys_path (str | None): Additional path to be appended to `sys.path`. + Default: None. 
+ """ + def __init__(self, server_list_cfg, client_cfg, sys_path=None): + if sys_path is not None: + import sys + sys.path.append(sys_path) + try: + import mc + except ImportError: + raise ImportError( + 'Please install memcached to enable MemcachedBackend.') + self.server_list_cfg = server_list_cfg + self.client_cfg = client_cfg + self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, + self.client_cfg) + self._mc_buffer = mc.pyvector() + + def get(self, filepath): + filepath = str(filepath) + import mc + self._client.Get(filepath, self._mc_buffer) + value_buf = mc.ConvertBuffer(self._mc_buffer) + return value_buf + + def get_text(self, filepath): + raise NotImplementedError + + +class HardDiskBackend(BaseStorageBackend): + """Raw hard disks storage backend.""" + def get(self, filepath): + filepath = str(filepath) + with open(filepath, 'rb') as f: + value_buf = f.read() + return value_buf + + def get_text(self, filepath): + filepath = str(filepath) + with open(filepath, 'r') as f: + value_buf = f.read() + return value_buf + + +class LmdbBackend(BaseStorageBackend): + """Lmdb storage backend. + + Args: + db_paths (str | list[str]): Lmdb database paths. + client_keys (str | list[str]): Lmdb client keys. Default: 'default'. + readonly (bool, optional): Lmdb environment parameter. If True, + disallow any write operations. Default: True. + lock (bool, optional): Lmdb environment parameter. If False, when + concurrent access occurs, do not lock the database. Default: False. + readahead (bool, optional): Lmdb environment parameter. If False, + disable the OS filesystem readahead mechanism, which may improve + random read performance when a database is larger than RAM. + Default: False. + + Attributes: + db_paths (list): Lmdb database path. + _client (list): A list of several lmdb envs. + """ + def __init__(self, + db_paths, + client_keys='default', + readonly=True, + lock=False, + readahead=False, + **kwargs): + try: + import lmdb + except ImportError: + raise ImportError('Please install lmdb to enable LmdbBackend.') + if isinstance(client_keys, str): + client_keys = [client_keys] + if isinstance(db_paths, list): + self.db_paths = [str(v) for v in db_paths] + elif isinstance(db_paths, str): + self.db_paths = [str(db_paths)] + assert len(client_keys) == len( + self.db_paths + ), f'client_keys and db_paths should have the same length, but received {len(client_keys)} and {len(self.db_paths)}.' + self._client = {} + for client, path in zip(client_keys, self.db_paths): + self._client[client] = lmdb.open(path, readonly=readonly, lock=\ + lock, readahead=readahead, **kwargs) + + def get(self, filepath, client_key): + """Get values according to the filepath from one lmdb named client_key. + + Args: + filepath (str | obj:`Path`): Here, filepath is the lmdb key. + client_key (str): Used for distinguishing different lmdb envs. + """ + filepath = str(filepath) + assert client_key in self._client, f'client_key {client_key} is not in lmdb clients.' + client = self._client[client_key] + with client.begin(write=False) as txn: + value_buf = txn.get(filepath.encode('ascii')) + return value_buf + + def get_text(self, filepath): + raise NotImplementedError + + +class FileClient(object): + """A general file client to access files in different backend. + + The client loads a file or text in a specified backend from its path + and return it as a binary file. it can also register other backend + accessor with a given name and backend class. + + Attributes: + backend (str): The storage backend type. 
Options are "disk", + "memcached" and "lmdb". + client (:obj:`BaseStorageBackend`): The backend object. + """ + _backends = { + 'disk': HardDiskBackend, + 'memcached': MemcachedBackend, + 'lmdb': LmdbBackend + } + + def __init__(self, backend='disk', **kwargs): + if backend not in self._backends: + raise ValueError( + f'Backend {backend} is not supported. Currently supported ones are {list(self._backends.keys())}' + ) + self.backend = backend + self.client = self._backends[backend](**kwargs) + + def get(self, filepath, client_key='default'): + if self.backend == 'lmdb': + return self.client.get(filepath, client_key) + else: + return self.client.get(filepath) + + def get_text(self, filepath): + return self.client.get_text(filepath) + + +def imfrombytes(content, flag='color', float32=False): + """Read an image from bytes. + + Args: + content (bytes): Image bytes got from files or other streams. + flag (str): Flags specifying the color type of a loaded image, + candidates are `color`, `grayscale` and `unchanged`. + float32 (bool): Whether to change to float32., If True, will also norm + to [0, 1]. Default: False. + + Returns: + ndarray: Loaded image array. + """ + img_np = np.frombuffer(content, np.uint8) + imread_flags = { + 'color': cv2.IMREAD_COLOR, + 'grayscale': cv2.IMREAD_GRAYSCALE, + 'unchanged': cv2.IMREAD_UNCHANGED + } + img = cv2.imdecode(img_np, imread_flags[flag]) + if float32: + img = img.astype(np.float32) / 255.0 + return img + + +def img2tensor(imgs, bgr2rgb=True, float32=True): + """Numpy array to tensor. + + Args: + imgs (list[ndarray] | ndarray): Input images. + bgr2rgb (bool): Whether to change bgr to rgb. + float32 (bool): Whether to change to float32. + + Returns: + list[tensor] | tensor: Tensor images. If returned results only have + one element, just return tensor. + """ + def _totensor(img, bgr2rgb, float32): + if img.shape[2] == 3 and bgr2rgb: + if img.dtype == 'float64': + img = img.astype('float32') + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + + return img.transpose(2, 0, 1) + + if isinstance(imgs, list): + return [_totensor(img, bgr2rgb, float32) for img in imgs] + else: + return _totensor(imgs, bgr2rgb, float32) + + +def mod_crop(img, scale): + """Mod crop images, used during testing. + + Args: + img (ndarray): Input image. + scale (int): Scale factor. + + Returns: + ndarray: Result image. + """ + img = img.copy() + if img.ndim in (2, 3): + h, w = img.shape[0], img.shape[1] + h_remainder, w_remainder = h % scale, w % scale + img = img[:h - h_remainder, :w - w_remainder, ...] + else: + raise ValueError(f'Wrong img ndim: {img.ndim}.') + return img + + +def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path=None): + """Paired random crop. Support Numpy array and Tensor inputs. + + It crops lists of lq and gt images with corresponding locations. + + Args: + img_gts (list[ndarray] | ndarray | list[Tensor] | Tensor): GT images. Note that all images + should have the same shape. If the input is an ndarray, it will + be transformed to a list containing itself. + img_lqs (list[ndarray] | ndarray): LQ images. Note that all images + should have the same shape. If the input is an ndarray, it will + be transformed to a list containing itself. + gt_patch_size (int): GT patch size. + scale (int): Scale factor. + gt_path (str): Path to ground-truth. Default: None. + + Returns: + list[ndarray] | ndarray: GT images and LQ images. If returned results + only have one element, just return ndarray. 
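Putting the I/O helpers together, a typical read path for the `disk` backend used by the datasets looks roughly like the sketch below (the file path is a hypothetical example). Note that in this port `img2tensor` returns a CHW numpy array, so the caller still wraps the result in a Paddle tensor:

```python
import paddle

client = FileClient('disk')
img_bytes = client.get('data/gfpgan_data/train/00000.png')   # hypothetical example path
img = imfrombytes(img_bytes, float32=True)                    # HWC BGR, float32, range [0, 1]
img = img2tensor(img, bgr2rgb=True)                           # CHW RGB numpy array
img = paddle.to_tensor(img)
```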
+ """ + if not isinstance(img_gts, list): + img_gts = [img_gts] + if not isinstance(img_lqs, list): + img_lqs = [img_lqs] + input_type = 'Tensor' if isinstance(img_gts[0], paddle.Tensor) else 'Numpy' + if input_type == 'Tensor': + h_lq, w_lq = img_lqs[0].size()[-2:] + h_gt, w_gt = img_gts[0].size()[-2:] + else: + h_lq, w_lq = img_lqs[0].shape[0:2] + h_gt, w_gt = img_gts[0].shape[0:2] + lq_patch_size = gt_patch_size // scale + if h_gt != h_lq * scale or w_gt != w_lq * scale: + raise ValueError( + f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x ', + f'multiplication of LQ ({h_lq}, {w_lq}).') + if h_lq < lq_patch_size or w_lq < lq_patch_size: + raise ValueError( + f'LQ ({h_lq}, {w_lq}) is smaller than patch size ({lq_patch_size}, {lq_patch_size}). Please remove {gt_path}.' + ) + top = random.randint(0, h_lq - lq_patch_size) + left = random.randint(0, w_lq - lq_patch_size) + if input_type == 'Tensor': + img_lqs = [ + v[:, :, top:top + lq_patch_size, left:left + lq_patch_size] + for v in img_lqs + ] + else: + img_lqs = [ + v[top:top + lq_patch_size, left:left + lq_patch_size, ...] + for v in img_lqs + ] + top_gt, left_gt = int(top * scale), int(left * scale) + if input_type == 'Tensor': + img_gts = [ + v[:, :, top_gt:top_gt + gt_patch_size, + left_gt:left_gt + gt_patch_size] for v in img_gts + ] + else: + img_gts = [ + v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, + ...] for v in img_gts + ] + if len(img_gts) == 1: + img_gts = img_gts[0] + if len(img_lqs) == 1: + img_lqs = img_lqs[0] + return img_gts, img_lqs + + +def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False): + """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees). + + We use vertical flip and transpose for rotation implementation. + All the images in the list use the same augmentation. + + Args: + imgs (list[ndarray] | ndarray): Images to be augmented. If the input + is an ndarray, it will be transformed to a list. + hflip (bool): Horizontal flip. Default: True. + rotation (bool): Ratotation. Default: True. + flows (list[ndarray]: Flows to be augmented. If the input is an + ndarray, it will be transformed to a list. + Dimension is (h, w, 2). Default: None. + return_status (bool): Return the status of flip and rotation. + Default: False. + + Returns: + list[ndarray] | ndarray: Augmented images and flows. If returned + results only have one element, just return ndarray. + + """ + hflip = hflip and random.random() < 0.5 + vflip = rotation and random.random() < 0.5 + rot90 = rotation and random.random() < 0.5 + + def _augment(img): + if hflip: + cv2.flip(img, 1, img) + if vflip: + cv2.flip(img, 0, img) + if rot90: + img = img.transpose(1, 0, 2) + return img + + def _augment_flow(flow): + if hflip: + cv2.flip(flow, 1, flow) + flow[:, :, 0] *= -1 + if vflip: + cv2.flip(flow, 0, flow) + flow[:, :, 1] *= -1 + if rot90: + flow = flow.transpose(1, 0, 2) + flow = flow[:, :, [1, 0]] + return flow + + if not isinstance(imgs, list): + imgs = [imgs] + imgs = [_augment(img) for img in imgs] + if len(imgs) == 1: + imgs = imgs[0] + if flows is not None: + if not isinstance(flows, list): + flows = [flows] + flows = [_augment_flow(flow) for flow in flows] + if len(flows) == 1: + flows = flows[0] + return imgs, flows + elif return_status: + return imgs, (hflip, vflip, rot90) + else: + return imgs + + +def img_rotate(img, angle, center=None, scale=1.0): + """Rotate image. + + Args: + img (ndarray): Image to be rotated. + angle (float): Rotation angle in degrees. 
Positive values mean + counter-clockwise rotation. + center (tuple[int]): Rotation center. If the center is None, + initialize it as the center of the image. Default: None. + scale (float): Isotropic scale factor. Default: 1.0. + """ + h, w = img.shape[:2] + if center is None: + center = w // 2, h // 2 + matrix = cv2.getRotationMatrix2D(center, angle, scale) + rotated_img = cv2.warpAffine(img, matrix, (w, h)) + return rotated_img + + +def _convert_input_type_range(img): + """Convert the type and range of the input image. + + It converts the input image to np.float32 type and range of [0, 1]. + It is mainly used for pre-processing the input image in colorspace + conversion functions such as rgb2ycbcr and ycbcr2rgb. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + + Returns: + (ndarray): The converted image with type of np.float32 and range of + [0, 1]. + """ + img_type = img.dtype + img = img.astype(np.float32) + if img_type == np.float32: + pass + elif img_type == np.uint8: + img /= 255. + else: + raise TypeError( + f'The img type should be np.float32 or np.uint8, but got {img_type}' + ) + return img + + +def _convert_output_type_range(img, dst_type): + """Convert the type and range of the image according to dst_type. + + It converts the image to desired type and range. If `dst_type` is np.uint8, + images will be converted to np.uint8 type with range [0, 255]. If + `dst_type` is np.float32, it converts the image to np.float32 type with + range [0, 1]. + It is mainly used for post-processing images in colorspace conversion + functions such as rgb2ycbcr and ycbcr2rgb. + + Args: + img (ndarray): The image to be converted with np.float32 type and + range [0, 255]. + dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it + converts the image to np.uint8 type with range [0, 255]. If + dst_type is np.float32, it converts the image to np.float32 type + with range [0, 1]. + + Returns: + (ndarray): The converted image with desired type and range. + """ + if dst_type not in (np.uint8, np.float32): + raise TypeError( + f'The dst_type should be np.float32 or np.uint8, but got {dst_type}' + ) + if dst_type == np.uint8: + img = img.round() + else: + img /= 255. + return img.astype(dst_type) + + +def bgr2ycbcr(img, y_only=False): + """Convert a BGR image to YCbCr image. + + The bgr version of rgb2ycbcr. + It implements the ITU-R BT.601 conversion for standard-definition + television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`. + In OpenCV, it implements a JPEG conversion. See more details in + https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + y_only (bool): Whether to only return Y channel. Default: False. + + Returns: + ndarray: The converted YCbCr image. The output image has the same type + and range as input image. 
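`_convert_input_type_range` and `_convert_output_type_range` bracket the colour-space conversions: inputs are normalised to float32 `[0, 1]`, the conversion itself works in the `[0, 255]` domain, and the output is cast back to the input's original type and range. A small hedged round-trip sketch:

```python
import numpy as np

img_u8 = (np.random.rand(8, 8, 3) * 255).astype(np.uint8)
f01 = _convert_input_type_range(img_u8)                # float32 in [0, 1]
converted = f01 * 255.0                                # conversions such as bgr2ycbcr operate on [0, 255]
out = _convert_output_type_range(converted, np.uint8)  # back to uint8 in [0, 255]
```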
+ """ + img_type = img.dtype + img = _convert_input_type_range(img) + if y_only: + out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0 + else: + out_img = np.matmul( + img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], + [65.481, -37.797, 112.0]]) + [16, 128, 128] + out_img = _convert_output_type_range(out_img, img_type) + return out_img + + +def paired_paths_from_folder(folders, keys, filename_tmpl): + """Generate paired paths from folders. + + Args: + folders (list[str]): A list of folder path. The order of list should + be [input_folder, gt_folder]. + keys (list[str]): A list of keys identifying folders. The order should + be in consistent with folders, e.g., ['lq', 'gt']. + filename_tmpl (str): Template for each filename. Note that the + template excludes the file extension. Usually the filename_tmpl is + for files in the input folder. + + Returns: + list[str]: Returned path list. + """ + assert len(folders) == 2, ( + 'The len of folders should be 2 with [input_folder, gt_folder]. ' + f'But got {len(folders)}') + assert len( + keys + ) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}' + input_folder, gt_folder = folders + input_key, gt_key = keys + + input_paths = list(scandir(input_folder)) + gt_paths = list(scandir(gt_folder)) + assert len(input_paths) == len(gt_paths), ( + f'{input_key} and {gt_key} datasets have different number of images: ' + f'{len(input_paths)}, {len(gt_paths)}.') + paths = [] + for gt_path in gt_paths: + basename, ext = osp.splitext(osp.basename(gt_path)) + input_name = f'{filename_tmpl.format(basename)}{ext}' + input_path = osp.join(input_folder, input_name) + assert input_name in input_paths, f'{input_name} is not in {input_key}_paths.' + gt_path = osp.join(gt_folder, gt_path) + paths.append( + dict([(f'{input_key}_path', input_path), + (f'{gt_key}_path', gt_path)])) + return paths + + +def paired_paths_from_lmdb(folders, keys): + """Generate paired paths from lmdb files. + + Contents of lmdb. Taking the `lq.lmdb` for example, the file structure is: + + lq.lmdb + ├── data.mdb + ├── lock.mdb + ├── meta_info.txt + + The data.mdb and lock.mdb are standard lmdb files and you can refer to + https://lmdb.readthedocs.io/en/release/ for more details. + + The meta_info.txt is a specified txt file to record the meta information + of our datasets. It will be automatically created when preparing + datasets by our provided dataset tools. + Each line in the txt file records + 1)image name (with extension), + 2)image shape, + 3)compression level, separated by a white space. + Example: `baboon.png (120,125,3) 1` + + We use the image name without extension as the lmdb key. + Note that we use the same key for the corresponding lq and gt images. + + Args: + folders (list[str]): A list of folder path. The order of list should + be [input_folder, gt_folder]. + keys (list[str]): A list of keys identifying folders. The order should + be in consistent with folders, e.g., ['lq', 'gt']. + Note that this key is different from lmdb keys. + + Returns: + list[str]: Returned path list. + """ + assert len(folders) == 2, ( + 'The len of folders should be 2 with [input_folder, gt_folder]. ' + f'But got {len(folders)}') + assert len( + keys + ) == 2, f'The len of keys should be 2 with [input_key, gt_key]. 
+    input_folder, gt_folder = folders
+    input_key, gt_key = keys
+
+    if not (input_folder.endswith('.lmdb') and gt_folder.endswith('.lmdb')):
+        raise ValueError(
+            f'{input_key} folder and {gt_key} folder should both be in lmdb '
+            f'format. But received {input_key}: {input_folder}; '
+            f'{gt_key}: {gt_folder}')
+    # ensure that the two meta_info files are the same
+    with open(osp.join(input_folder, 'meta_info.txt')) as fin:
+        input_lmdb_keys = [line.split('.')[0] for line in fin]
+    with open(osp.join(gt_folder, 'meta_info.txt')) as fin:
+        gt_lmdb_keys = [line.split('.')[0] for line in fin]
+    if set(input_lmdb_keys) != set(gt_lmdb_keys):
+        raise ValueError(
+            f'Keys in {input_key}_folder and {gt_key}_folder are different.')
+    else:
+        paths = []
+        for lmdb_key in sorted(input_lmdb_keys):
+            paths.append(
+                dict([(f'{input_key}_path', lmdb_key),
+                      (f'{gt_key}_path', lmdb_key)]))
+        return paths
+
+
+def paired_paths_from_meta_info_file(folders, keys, meta_info_file,
+                                     filename_tmpl):
+    """Generate paired paths from a meta information file.
+
+    Each line in the meta information file contains the image names and
+    image shape (usually for gt), separated by a white space.
+
+    Example of a meta information file:
+    ```
+    0001_s001.png (480,480,3)
+    0001_s002.png (480,480,3)
+    ```
+
+    Args:
+        folders (list[str]): A list of folder paths. The order of the list
+            should be [input_folder, gt_folder].
+        keys (list[str]): A list of keys identifying folders. The order should
+            be consistent with folders, e.g., ['lq', 'gt'].
+        meta_info_file (str): Path to the meta information file.
+        filename_tmpl (str): Template for each filename. Note that the
+            template excludes the file extension. Usually the filename_tmpl is
+            for files in the input folder.
+
+    Returns:
+        list[str]: Returned path list.
+    """
+    assert len(folders) == 2, (
+        'The len of folders should be 2 with [input_folder, gt_folder]. '
+        f'But got {len(folders)}')
+    assert len(
+        keys
+    ) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}'
+    input_folder, gt_folder = folders
+    input_key, gt_key = keys
+
+    with open(meta_info_file, 'r') as fin:
+        gt_names = [line.strip().split(' ')[0] for line in fin]
+
+    paths = []
+    for gt_name in gt_names:
+        basename, ext = osp.splitext(osp.basename(gt_name))
+        input_name = f'{filename_tmpl.format(basename)}{ext}'
+        input_path = osp.join(input_folder, input_name)
+        gt_path = osp.join(gt_folder, gt_name)
+        paths.append(
+            dict([(f'{input_key}_path', input_path),
+                  (f'{gt_key}_path', gt_path)]))
+    return paths
diff --git a/test_tipc/configs/GFPGAN/train_infer_python.txt b/test_tipc/configs/GFPGAN/train_infer_python.txt
new file mode 100644
index 0000000..30e4abe
--- /dev/null
+++ b/test_tipc/configs/GFPGAN/train_infer_python.txt
@@ -0,0 +1,59 @@
+===========================train_params===========================
+model_name:GFPGAN
+python:python3.7
+gpu_list:0
+##
+auto_cast:null
+total_iters:lite_train_lite_infer=10
+output_dir:./output/
+dataset.train.batch_size:lite_train_lite_infer=3
+pretrained_model:null
+train_model_name:gfpgan_ffhq1024*/*checkpoint.pdparams
+train_infer_img_dir:null
+null:null
+##
+trainer:norm_train
+norm_train:tools/main.py -c configs/gfpgan_ffhq1024.yaml --seed 123 -o log_config.interval=1 snapshot_config.interval=10
+pact_train:null
+fpgm_train:null
+distill_train:null
+null:null
+null:null
+##
+===========================eval_params===========================
+eval:null
+null:null
+##
+===========================infer_params===========================
+--output_dir:./output/
+load:null
+norm_export:tools/export_model.py -c configs/gfpgan_ffhq1024.yaml --inputs_size="1,3,512,512" --model_name inference --load
+quant_export:null
+fpgm_export:null
+distill_export:null
+export1:null
+export2:null
+inference_dir:inference
+train_model:./inference/stylegan2/stylegan2model_gen
+infer_export:null
+infer_quant:False
+inference:tools/inference.py --model_type gfpgan --seed 123 -c configs/gfpgan_ffhq1024.yaml --output_path test_tipc/output/ -o validate=None
+--device:gpu
+null:null
+null:null
+null:null
+null:null
+null:null
+--model_path:
+null:null
+null:null
+--benchmark:False
+null:null
+===========================train_benchmark_params==========================
+batch_size:8
+fp_items:fp32
+epoch:100
+--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
+flags:FLAGS_cudnn_exhaustive_search=1
+===========================infer_benchmark_params==========================
+random_infer_input:[{float32,[1, 512]}, {float32,[1]}]
diff --git a/test_tipc/prepare.sh b/test_tipc/prepare.sh
index 6706af8..aedeee4 100644
--- a/test_tipc/prepare.sh
+++ b/test_tipc/prepare.sh
@@ -76,6 +76,11 @@ if [ ${MODE} = "lite_train_lite_infer" ];then
         cd ./data/ && unzip -q singan-official_images.zip && cd ../
         mkdir -p ./data/singan
         mv ./data/SinGAN-official_images/Images/stone.png ./data/singan ;;
+    GFPGAN)
+        rm -rf ./data/gfpgan*
+        wget -nc -P ./data/ https://paddlegan.bj.bcebos.com/datasets/gfpgan_tipc_data.zip --no-check-certificate
+        mkdir -p ./data/gfpgan_data
+        cd ./data/ && unzip -q gfpgan_tipc_data.zip -d gfpgan_data/ && cd ../ ;;
     esac
 elif [ ${MODE} = "whole_train_whole_infer" ];then
     if [ ${model_name} == "Pix2pix" ]; then
diff --git a/tools/inference.py b/tools/inference.py
index 02a1791..0c87503 100644
--- a/tools/inference.py
+++ b/tools/inference.py
@@ -334,6 +334,15 @@ def main():
             metric_file = os.path.join(args.output_path, "singan/metric.txt")
             for metric in metrics.values():
                 metric.update(prediction, data['A'])
+        elif model_type == 'gfpgan':
+            input_handles[0].copy_from_cpu(data['lq'].numpy())
+            predictor.run()
+            prediction = output_handle.copy_to_cpu()
+            prediction = paddle.to_tensor(prediction)
+            image_numpy = tensor2img(prediction, min_max)
+            save_image(
+                image_numpy,
+                os.path.join(args.output_path, "gfpgan/{}.png".format(i)))
         elif model_type == "swinir":
             lq = data[1].numpy()
             _, _, h_old, w_old = lq.shape
-- 
GitLab
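
For reference (outside the patch itself): a minimal usage sketch of the path-pairing helper added to ppgan/utils/gfpgan_tools.py above. The folder names mirror the dataroot_lq/dataroot_gt values in configs/gfpgan_ffhq1024.yaml; the '{}' template and the 0001.png file name are illustrative assumptions, not files shipped with this patch.

# Minimal sketch (assumed layout, not part of the patch): generate
# lq/gt path pairs from two folders that contain identically named files.
from ppgan.utils.gfpgan_tools import paired_paths_from_folder

paths = paired_paths_from_folder(
    folders=['data/gfpgan_data/lq', 'data/gfpgan_data/gt'],
    keys=['lq', 'gt'],
    filename_tmpl='{}')  # '{}' keeps the gt basename unchanged for the lq file
# Each entry is a dict such as:
# {'lq_path': 'data/gfpgan_data/lq/0001.png',
#  'gt_path': 'data/gfpgan_data/gt/0001.png'}
print(paths[0])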