Unverified commit 45922b0c, authored by yangshurong, committed by GitHub

[Paper reproduction] GFPGAN (#703)

* gfpgan push

* gfpgan finish

* gfpgan add init

* gfpgan del recover

* 11111

* gfpgan change name

* gfpgan recover name

* Update GFPGAN.md
Co-authored-by: Nwangna11BD <79366697+wangna11BD@users.noreply.github.com>
Parent 956efd9d
total_iters: 800000
output_dir: output
find_unused_parameters: True
log_config:
interval: 100
visiual_interval: 100
snapshot_config:
interval: 30000
enable_visualdl: False
validate:
interval: 5000
save_img: True
metrics:
psnr:
name: PSNR
crop_border: 0
test_y_channel: false
fid:
name: FID
batch_size: 8
model:
name: GFPGANModel
network_g:
name: GFPGANv1
out_size: 512
num_style_feat: 512
channel_multiplier: 1
resample_kernel: [1, 3, 3, 1]
decoder_load_path: https://paddlegan.bj.bcebos.com/models/StyleGAN2_FFHQ512_Cmul1.pdparams
fix_decoder: true
num_mlp: 8
lr_mlp: 0.01
input_is_latent: true
different_w: true
narrow: 1
sft_half: true
network_d:
name: StyleGAN2DiscriminatorGFPGAN
out_size: 512
channel_multiplier: 1
resample_kernel: [1, 3, 3, 1]
network_d_left_eye:
type: FacialComponentDiscriminator
network_d_right_eye:
type: FacialComponentDiscriminator
network_d_mouth:
type: FacialComponentDiscriminator
network_identity:
name: ResNetArcFace
block: IRBlock
layers: [2, 2, 2, 2]
use_se: False
path:
image_visual: gfpgan_train_outdir
pretrain_network_g: ~
param_key_g: params_ema
strict_load_g: ~
pretrain_network_d: ~
pretrain_network_d_left_eye: https://paddlegan.bj.bcebos.com/models/Facial_component_discriminator.pdparams
pretrain_network_d_right_eye: https://paddlegan.bj.bcebos.com/models/Facial_component_discriminator.pdparams
pretrain_network_d_mouth: https://paddlegan.bj.bcebos.com/models/Facial_component_discriminator.pdparams
pretrain_network_identity: https://paddlegan.bj.bcebos.com/models/arcface_resnet18.pdparams
# losses
# pixel loss
pixel_opt:
name: GFPGANL1Loss
loss_weight: !!float 1e-1
reduction: mean
# L1 loss used in pyramid loss, component style loss and identity loss
L1_opt:
name: GFPGANL1Loss
loss_weight: 1
reduction: mean
# image pyramid loss
pyramid_loss_weight: 1
remove_pyramid_loss: 50000
# perceptual loss (content and style losses)
perceptual_opt:
name: GFPGANPerceptualLoss
layer_weights:
# before relu
"conv1_2": 0.1
"conv2_2": 0.1
"conv3_4": 1
"conv4_4": 1
"conv5_4": 1
vgg_type: vgg19
use_input_norm: true
perceptual_weight: !!float 1
style_weight: 50
range_norm: true
criterion: l1
# gan loss
gan_opt:
name: GFPGANGANLoss
gan_type: wgan_softplus
loss_weight: !!float 1e-1
# r1 regularization for discriminator
r1_reg_weight: 10
# facial component loss
gan_component_opt:
name: GFPGANGANLoss
gan_type: vanilla
real_label_val: 1.0
fake_label_val: 0.0
loss_weight: !!float 1
comp_style_weight: 200
# identity loss
identity_weight: 10
net_d_iters: 1
net_d_init_iters: 0
net_d_reg_every: 16
export_model:
- { name: "net_g_ema", inputs_num: 1 }
dataset:
train:
name: FFHQDegradationDataset
dataroot_gt: data/gfpgan_data/train
io_backend:
type: disk
use_hflip: true
mean: [0.5, 0.5, 0.5]
std: [0.5, 0.5, 0.5]
out_size: 512
blur_kernel_size: 41
kernel_list: ["iso", "aniso"]
kernel_prob: [0.5, 0.5]
blur_sigma: [0.1, 10]
downsample_range: [0.8, 8]
noise_range: [0, 20]
jpeg_range: [60, 100]
# color jitter and gray
color_jitter_prob: 0.3
color_jitter_shift: 20
color_jitter_pt_prob: 0.3
gray_prob: 0.01
# If you do not want colorization, please set
# color_jitter_prob: ~
# color_jitter_pt_prob: ~
# gray_prob: 0.01
# gt_gray: True
crop_components: true
component_path: https://paddlegan.bj.bcebos.com/models/FFHQ_eye_mouth_landmarks_512.pdparams
eye_enlarge_ratio: 1.4
# data loader
use_shuffle: true
num_workers: 4
batch_size: 1
prefetch_mode: ~
test:
# Please modify accordingly to use your own validation
# Or comment the val block if do not need validation during training
name: PairedImageDataset
dataroot_lq: data/gfpgan_data/lq
dataroot_gt: data/gfpgan_data/gt
io_backend:
type: disk
mean: [0.5, 0.5, 0.5]
std: [0.5, 0.5, 0.5]
scale: 1
num_workers: 4
batch_size: 8
phase: val
lr_scheduler:
name: MultiStepDecay
learning_rate: 0.002
milestones: [600000, 700000]
gamma: 0.5
optimizer:
optim_g:
name: Adam
beta1: 0
beta2: 0.99
optim_d:
name: Adam
beta1: 0
beta2: 0.99
optim_component:
name: Adam
beta1: 0.9
beta2: 0.99
## GFPGAN Blind Face Restoration Model
## 1. Introduction
GFP-GAN leverages the rich and diverse priors encapsulated in a pretrained face GAN for blind face restoration.
### Overview of GFP-GAN framework:
![image](https://user-images.githubusercontent.com/73787862/191736718-72f5aa09-d7a9-490b-b1f8-b609208d4654.png)
GFP-GAN consists of a degradation removal module (a U-Net) and a pretrained face GAN (such as StyleGAN2) that serves as a prior. The two are bridged by a latent code mapping and several Channel-Split Spatial Feature Transform (CS-SFT) layers.
By modulating features in this way, the model achieves realistic results while preserving high fidelity.
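To make the Channel-Split SFT modulation concrete, here is a minimal sketch (the helper name `csft_modulate` is ours, not part of the library; the split-modulate-concat logic mirrors what the decoder does when `sft_half` is enabled):

```python
import paddle

def csft_modulate(feat, scale, shift, sft_half=True):
    # CS-SFT: spatially scale and shift only half of the channels;
    # the other half passes through unchanged to preserve fidelity.
    if sft_half:
        feat_same, feat_sft = paddle.split(feat, 2, axis=1)
        feat_sft = feat_sft * scale + shift
        return paddle.concat([feat_same, feat_sft], axis=1)
    return feat * scale + shift

feat = paddle.randn([1, 64, 32, 32])   # decoder feature map
scale = paddle.randn([1, 32, 32, 32])  # spatial scale condition (half the channels)
shift = paddle.randn([1, 32, 32, 32])  # spatial shift condition (half the channels)
print(csft_modulate(feat, scale, shift).shape)  # [1, 64, 32, 32]
```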
For a more detailed introduction to the model and the reference repository, see the following AI Studio project:
[https://aistudio.baidu.com/aistudio/projectdetail/4421649](https://aistudio.baidu.com/aistudio/projectdetail/4421649)
In this experiment, we train the model with the Adam optimizer for a total of 210k iterations.
The restoration results of GFPGAN are as follows:
Model | LPIPS | FID | PSNR
--- |:---:|:---:|:---:|
GFPGAN | 0.3817 | 36.8068 | 65.0461
## 2. Preparation
### 2.1 Dataset Preparation
The GFPGAN training set is the classic FFHQ face dataset, with a total of 70,000 high-resolution 1024 x 1024 face images.
The test set is the CelebA-HQ dataset, with a total of 2,000 high-resolution face images; the low-quality test inputs are generated in the same way as during training.
For details, please refer to **Dataset URL:** [FFHQ](https://github.com/NVlabs/ffhq-dataset), [CELEBA-HQ](https://github.com/tkarras/progressive_growing_of_gans).
The specific download links are given below:
**Original dataset download address:**
**FFHQ :** https://drive.google.com/drive/folders/1tZUcXDBeOibC6jcMCtgRRz67pzrAHeHL?usp=drive_open
**CELEBA-HQ:** https://drive.google.com/drive/folders/0B4qLcYyJmiz0TXY1NG02bzZVRGs?resourcekey=0-arAVTUfW9KRhN-irJchVKQ&usp=sharing
The data should be organized as follows:
```
|-- data/GFPGAN
|-- train
|-- 00000.png
|-- 00001.png
|-- ......
|-- 00999.png
|-- ......
|-- 69999.png
|-- lq
|-- 2000 JPG images
|-- gt
|-- 2000 JPG images
```
Please modify the dataroot parameters of the train and test datasets in the configs/gfpgan_ffhq1024.yaml configuration file to point to your own training set and test set paths.
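As a quick sanity check (a minimal sketch; the paths below match the defaults in the config shipped with this change), you can verify that the configured directories exist and are populated:

```python
import os

# paths assumed to match the dataroot_* values in configs/gfpgan_ffhq1024.yaml
for name, path in [('train', 'data/gfpgan_data/train'),
                   ('lq', 'data/gfpgan_data/lq'),
                   ('gt', 'data/gfpgan_data/gt')]:
    count = len(os.listdir(path)) if os.path.isdir(path) else 0
    print(f'{name}: {path} -> {count} files')
```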
### 2.2 Model preparation
**Model parameter file and training log download address:**
https://paddlegan.bj.bcebos.com/models/GFPGAN.pdparams
Download the model parameters and test images from the link above and put them in the data/ folder in the project root directory.
The parameter file is a Python dict that can be loaded with paddle.load. It contains the keys net_g and net_g_ema; either of them can be used for inference.
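For example, a quick way to inspect the downloaded weights (a minimal sketch; the path assumes you placed the file under data/):

```python
import paddle

state = paddle.load('data/GFPGAN.pdparams')
print(list(state.keys()))      # expected to contain 'net_g' and 'net_g_ema'
weights = state['net_g_ema']   # either sub-dict can be used for inference
print(len(weights), 'parameter tensors')
```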
## 3. Getting started
### 3.1 Model training
Run the following command to start training:
```bash
python tools/main.py -c configs/gfpgan_ffhq1024.yaml
```
The model supports both single-card and multi-card training. Multi-card training can be launched as follows:
```bash
export CUDA_VISIBLE_DEVICES=0,1,2,3
python -m paddle.distributed.launch tools/main.py \
    --config-file configs/gfpgan_ffhq1024.yaml
```
Model training requires PaddlePaddle 2.3 or above, and it additionally has to wait for Paddle to implement the second-order gradient of the elementwise_pow operator. PaddlePaddle 2.2.2 can run the code, but the model cannot be trained successfully because some loss functions compute incorrect gradients. If an error is reported during training, training is not supported for the time being; you can skip the training part and directly use the provided model parameters for testing. Model evaluation and testing work with PaddlePaddle 2.2.2 and above.
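A quick way to check whether your environment meets these requirements (a minimal sketch):

```python
import paddle

print(paddle.__version__)   # training needs 2.3+, evaluation and testing work on 2.2.2+
paddle.utils.run_check()    # verifies that PaddlePaddle and the GPU are installed correctly
```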
### 3.2 Model evaluation
To evaluate the model, run the following command with the downloaded model parameters mentioned above:
```shell
python tools/main.py -c configs/gfpgan_ffhq1024.yaml --load GFPGAN.pdparams --evaluate-only
```
If you want to test your own trained model, modify the path after --load accordingly.
### 3.3 Model prediction
#### 3.3.1 Export model
After training, use ``tools/export_model.py`` to extract the generator weights from the trained model (the exported model contains the generator only).
Enter the following command to export the generator:
```bash
python -u tools/export_model.py --config-file configs/gfpgan_ffhq1024.yaml \
--load GFPGAN.pdparams \
--inputs_size 1,3,512,512
```
#### 3.3.2 Process a single image
You can use the helper in ppgan/faceutils/face_enhancement/gfpgan_enhance.py to quickly run inference on a single image:
```python
import os
import sys
# equivalent of the notebook magics `%env PYTHONPATH=.:$PYTHONPATH` and `%env CUDA_VISIBLE_DEVICES=0`
sys.path.insert(0, '.')
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import paddle
import cv2
import numpy as np
from ppgan.faceutils.face_enhancement.gfpgan_enhance import gfp_FaceEnhancement
# you can use your own image path
img_path = 'test/2.png'
img = cv2.imread(img_path, cv2.IMREAD_COLOR)
# save the original low-quality picture
cv2.imwrite('test/outlq.png', img)
img = np.array(img).astype('float32')
faceenhancer = gfp_FaceEnhancement()
img = faceenhancer.enhance_from_image(img)
# save the restored result
cv2.imwrite('test/out_gfpgan.png', img)
```
![image](https://user-images.githubusercontent.com/73787862/191741112-b813a02c-6b19-4591-b80d-0bf5ce8ad07e.png)
![image](https://user-images.githubusercontent.com/73787862/191741242-1f365048-ba25-450f-8abc-76e74d8786f8.png)
## 4. TIPC
### 4.1 Export the inference model
```bash
python -u tools/export_model.py --config-file configs/gfpgan_ffhq1024.yaml \
--load GFPGAN.pdparams \
--inputs_size 1,3,512,512
```
You can also change the path after --load to the model parameter file you want to export.
### 4.2 Inference with a prediction engine
```bash
cd /home/aistudio/work/PaddleGAN
python -u tools/inference.py --config-file configs/gfpgan_ffhq1024.yaml \
--model_path GFPGAN.pdparams \
--model_type gfpgan \
--device gpu \
-o validate=None
```
### 4.3 Run the TIPC scripts to complete training and inference testing in two steps
To run the `lite_train_lite_infer` mode of the TIPC basic training-and-inference test, execute:
```bash
cd /home/aistudio/work/PaddleGAN
bash test_tipc/prepare.sh \
    test_tipc/configs/GFPGAN/train_infer_python.txt \
    lite_train_lite_infer
bash test_tipc/test_train_inference_python.sh \
test_tipc/configs/GFPGAN/train_infer_python.txt \
lite_train_lite_infer
```
## 5. References
```
@InProceedings{wang2021gfpgan,
author = {Xintao Wang and Yu Li and Honglun Zhang and Ying Shan},
title = {Towards Real-World Blind Face Restoration with Generative Facial Prior},
booktitle={The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
year = {2021}
}
```
## GFPGAN Blind Face Restoration Model
## 1. Introduction
GFP-GAN leverages the rich and diverse priors encapsulated in a pretrained face GAN for blind face restoration.
### Overview of the GFP-GAN framework:
![image](https://user-images.githubusercontent.com/73787862/191736718-72f5aa09-d7a9-490b-b1f8-b609208d4654.png)
GFP-GAN consists of a degradation removal module (a U-Net) and a pretrained face GAN (such as StyleGAN2) that serves as a prior. The two are bridged by a latent code mapping and several Channel-Split Spatial Feature Transform (CS-SFT) layers.
By modulating features in this way, the model achieves realistic results while preserving high fidelity.
For a more detailed introduction to the model and the reference repository, see the following AI Studio project:
[GFPGAN reproduced with PaddleGAN](https://aistudio.baidu.com/aistudio/projectdetail/4421649)
In this experiment, we train the model with the Adam optimizer for a total of 210k iterations.
The restoration results of GFPGAN are as follows:
Model | LPIPS | FID | PSNR
--- |:---:|:---:|:---:|
GFPGAN | 0.3817 | 36.8068 | 65.0461
## 2. Preparation
### 2.1 Dataset Preparation
The GFPGAN training set is the classic FFHQ face dataset, with a total of 70,000 high-resolution 1024 x 1024 face images.
The test set is the CelebA-HQ dataset, with a total of 2,000 high-resolution face images; the low-quality test inputs are generated in the same way as during training.
For details, please refer to **Dataset URL:** [FFHQ](https://github.com/NVlabs/ffhq-dataset), [CELEBA-HQ](https://github.com/tkarras/progressive_growing_of_gans).
The specific download links are given below:
**Original dataset download address:**
**FFHQ :** https://drive.google.com/drive/folders/1tZUcXDBeOibC6jcMCtgRRz67pzrAHeHL?usp=drive_open
**CELEBA-HQ:** https://drive.google.com/drive/folders/0B4qLcYyJmiz0TXY1NG02bzZVRGs?resourcekey=0-arAVTUfW9KRhN-irJchVKQ&usp=sharing
The data should be organized as follows:
```
|-- data/GFPGAN
|-- train
|-- 00000.png
|-- 00001.png
|-- ......
|-- 00999.png
|-- ......
|-- 69999.png
|-- lq
|-- 2000 JPG images
|-- gt
|-- 2000 JPG images
```
Please modify the dataroot parameters of the train and test datasets in the configs/gfpgan_ffhq1024.yaml configuration file to point to your own training set and test set paths.
### 2.2 Model preparation
**Model parameter file and training log download address:**
https://paddlegan.bj.bcebos.com/models/GFPGAN.pdparams
Download the model parameters and test images from the link above and put them in the data/ folder in the project root directory.
The parameter file is a Python dict that can be loaded with paddle.load. It contains the keys net_g and net_g_ema; either of them can be used for inference.
## 3. Getting started
### 3.1 Model training
Run the following command to start training:
```bash
python tools/main.py -c configs/gfpgan_ffhq1024.yaml
```
The model supports both single-card and multi-card training.
Multi-card training can be launched as follows:
```bash
export CUDA_VISIBLE_DEVICES=0,1,2,3
python -m paddle.distributed.launch tools/main.py \
    --config-file configs/gfpgan_ffhq1024.yaml
```
Model training requires PaddlePaddle 2.3 or above, and it additionally has to wait for Paddle to implement the second-order gradient of the elementwise_pow operator. PaddlePaddle 2.2.2 can run the code, but the model cannot be trained successfully because some loss functions compute incorrect gradients. If an error is reported during training, training is not supported for the time being; you can skip the training part and directly use the provided model parameters for testing. Model evaluation and testing work with PaddlePaddle 2.2.2 and above.
### 3.2 Model evaluation
To evaluate the model, run the following command with the downloaded model parameters mentioned above:
```shell
python tools/main.py -c configs/gfpgan_ffhq1024.yaml --load GFPGAN.pdparams --evaluate-only
```
If you want to test your own trained model, modify the path after --load accordingly.
### 3.3 Model prediction
#### 3.3.1 Export the model
After training, use ``tools/export_model.py`` to extract the generator weights from the trained model (the exported model contains the generator only).
Enter the following command to export the generator:
```bash
python -u tools/export_model.py --config-file configs/gfpgan_ffhq1024.yaml \
--load GFPGAN.pdparams \
--inputs_size 1,3,512,512
```
#### 3.3.2 Process a single image
You can use the helper in ppgan/faceutils/face_enhancement/gfpgan_enhance.py to quickly run inference on a single image:
```python
import os
import sys
# equivalent of the notebook magics `%env PYTHONPATH=.:$PYTHONPATH` and `%env CUDA_VISIBLE_DEVICES=0`
sys.path.insert(0, '.')
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import paddle
import cv2
import numpy as np
from ppgan.faceutils.face_enhancement.gfpgan_enhance import gfp_FaceEnhancement
# you can use your own image path
img_path='test/2.png'
img = cv2.imread(img_path, cv2.IMREAD_COLOR)
# save the original low-quality picture
cv2.imwrite('test/outlq.png',img)
img=np.array(img).astype('float32')
faceenhancer = gfp_FaceEnhancement()
img = faceenhancer.enhance_from_image(img)
# save the restored result
cv2.imwrite('test/out_gfpgan.png',img)
```
![image](https://user-images.githubusercontent.com/73787862/191741112-b813a02c-6b19-4591-b80d-0bf5ce8ad07e.png)
![image](https://user-images.githubusercontent.com/73787862/191741242-1f365048-ba25-450f-8abc-76e74d8786f8.png)
## 4. TIPC
### 4.1 Export the inference model
```bash
python -u tools/export_model.py --config-file configs/gfpgan_ffhq1024.yaml \
--load GFPGAN.pdparams \
--inputs_size 1,3,512,512
```
### 4.2 Inference with Paddle Inference
```bash
cd /home/aistudio/work/PaddleGAN
python -u tools/inference.py --config-file configs/gfpgan_ffhq1024.yaml \
--model_path GFPGAN.pdparams \
--model_type gfpgan \
--device gpu \
-o validate=None
```
### 4.3 One-click TIPC
To run the `lite_train_lite_infer` mode of the TIPC basic training-and-inference test, execute:
```bash
cd /home/aistudio/work/PaddleGAN
bash test_tipc/prepare.sh \
    test_tipc/configs/GFPGAN/train_infer_python.txt \
    lite_train_lite_infer
bash test_tipc/test_train_inference_python.sh \
test_tipc/configs/GFPGAN/train_infer_python.txt \
lite_train_lite_infer
```
## 5. References
```
@InProceedings{wang2021gfpgan,
author = {Xintao Wang and Yu Li and Honglun Zhang and Ying Shan},
title = {Towards Real-World Blind Face Restoration with Generative Facial Prior},
booktitle={The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
year = {2021}
}
```
@@ -32,4 +32,6 @@ from .photopen_dataset import PhotoPenDataset
from .empty_dataset import EmptyDataset
from .gpen_dataset import GPENDataset
from .swinir_dataset import SwinIRDataset
from .gfpgan_datasets import FFHQDegradationDataset
from .paired_image_datasets import PairedImageDataset
from .invdn_dataset import InvDNDataset
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import cv2
import math
import numpy as np
import random
import os
import paddle
import paddle.nn.functional as F
from paddle.vision.transforms.functional import normalize
from .builder import DATASETS
from ppgan.utils.download import get_path_from_url
from ppgan.utils.gfpgan_tools import *
@DATASETS.register()
class FFHQDegradationDataset(paddle.io.Dataset):
"""FFHQ dataset for GFPGAN.
It reads high resolution images, and then generate low-quality (LQ) images on-the-fly.
Args:
opt (dict): Config for train datasets. It contains the following keys:
dataroot_gt (str): Data root path for gt.
io_backend (dict): IO backend type and other kwarg.
mean (list | tuple): Image mean.
std (list | tuple): Image std.
use_hflip (bool): Whether to horizontally flip.
Please see more options in the codes.
"""
def __init__(self, **opt):
super(FFHQDegradationDataset, self).__init__()
self.opt = opt
self.file_client = None
self.io_backend_opt = opt['io_backend']
self.gt_folder = opt['dataroot_gt']
self.mean = opt['mean']
self.std = opt['std']
self.out_size = opt['out_size']
self.crop_components = opt.get('crop_components', False)
self.eye_enlarge_ratio = opt.get('eye_enlarge_ratio', 1)
if self.crop_components:
self.components_list = get_path_from_url(opt.get('component_path'))
self.components_list = paddle.load(self.components_list)
# print(self.components_list)
self.paths = paths_from_folder(self.gt_folder)
self.blur_kernel_size = opt['blur_kernel_size']
self.kernel_list = opt['kernel_list']
self.kernel_prob = opt['kernel_prob']
self.blur_sigma = opt['blur_sigma']
self.downsample_range = opt['downsample_range']
self.noise_range = opt['noise_range']
self.jpeg_range = opt['jpeg_range']
self.color_jitter_prob = opt.get('color_jitter_prob')
self.color_jitter_pt_prob = opt.get('color_jitter_pt_prob')
self.color_jitter_shift = opt.get('color_jitter_shift', 20)
self.gray_prob = opt.get('gray_prob')
self.color_jitter_shift /= 255.0
@staticmethod
def color_jitter(img, shift):
"""jitter color: randomly jitter the RGB values, in numpy formats"""
jitter_val = np.random.uniform(-shift, shift, 3).astype(np.float32)
img = img + jitter_val
img = np.clip(img, 0, 1)
return img
@staticmethod
def color_jitter_pt(img, brightness, contrast, saturation, hue):
"""jitter color: randomly jitter the brightness, contrast, saturation, and hue, in torch Tensor formats"""
fn_idx = paddle.randperm(4)
img = paddle.to_tensor(img, dtype=img.dtype)
for fn_id in fn_idx:
# print('fn_id',fn_id)
if fn_id == 0 and brightness is not None:
brightness_factor = paddle.to_tensor(1.0).uniform_(
brightness[0], brightness[1]).item()
# print("brightness_factor",brightness_factor)
img = adjust_brightness(img, brightness_factor)
if fn_id == 1 and contrast is not None:
contrast_factor = paddle.to_tensor(1.0).uniform_(
contrast[0], contrast[1]).item()
img = adjust_contrast(img, contrast_factor)
if fn_id == 2 and saturation is not None:
saturation_factor = paddle.to_tensor(1.0).uniform_(
saturation[0], saturation[1]).item()
img = adjust_saturation(img, saturation_factor)
if fn_id == 3 and hue is not None:
hue_factor = paddle.to_tensor(1.0).uniform_(hue[0],
hue[1]).item()
img = adjust_hue(img, hue_factor)
return img
def get_component_coordinates(self, index, status):
"""Get facial component (left_eye, right_eye, mouth) coordinates from a pre-loaded pth file"""
# print(f'{index:08d}',type(self.components_list))
components_bbox = self.components_list[f'{index:08d}']
if status[0]:
tmp = components_bbox['left_eye']
components_bbox['left_eye'] = components_bbox['right_eye']
components_bbox['right_eye'] = tmp
components_bbox['left_eye'][
0] = self.out_size - components_bbox['left_eye'][0]
components_bbox['right_eye'][
0] = self.out_size - components_bbox['right_eye'][0]
components_bbox['mouth'][
0] = self.out_size - components_bbox['mouth'][0]
locations = []
for part in ['left_eye', 'right_eye', 'mouth']:
mean = components_bbox[part][0:2]
half_len = components_bbox[part][2]
if 'eye' in part:
half_len *= self.eye_enlarge_ratio
loc = np.hstack((mean - half_len + 1, mean + half_len))
loc = paddle.to_tensor(loc)
locations.append(loc)
return locations
def __getitem__(self, index):
if self.file_client is None:
self.file_client = FileClient(self.io_backend_opt.pop('type'),
**self.io_backend_opt)
gt_path = self.paths[index]
img_bytes = self.file_client.get(gt_path)
img_gt = imfrombytes(img_bytes, float32=True)
img_gt = cv2.resize(img_gt, (self.out_size, self.out_size))
img_gt, status = augment(img_gt,
hflip=self.opt['use_hflip'],
rotation=False,
return_status=True)
h, w, _ = img_gt.shape
if self.crop_components:
locations = self.get_component_coordinates(index, status)
loc_left_eye, loc_right_eye, loc_mouth = locations
kernel = random_mixed_kernels(self.kernel_list,
self.kernel_prob,
self.blur_kernel_size,
self.blur_sigma,
self.blur_sigma, [-math.pi, math.pi],
noise_range=None)
img_lq = cv2.filter2D(img_gt, -1, kernel)
scale = np.random.uniform(self.downsample_range[0],
self.downsample_range[1])
img_lq = cv2.resize(img_lq, (int(w // scale), int(h // scale)),
interpolation=cv2.INTER_LINEAR)
if self.noise_range is not None:
img_lq = random_add_gaussian_noise(img_lq, self.noise_range)
if self.jpeg_range is not None:
img_lq = random_add_jpg_compression(img_lq, self.jpeg_range)
img_lq = cv2.resize(img_lq, (w, h), interpolation=cv2.INTER_LINEAR)
if self.color_jitter_prob is not None and np.random.uniform(
) < self.color_jitter_prob:
img_lq = self.color_jitter(img_lq, self.color_jitter_shift)
if self.gray_prob and np.random.uniform() < self.gray_prob:
img_lq = cv2.cvtColor(img_lq, cv2.COLOR_BGR2GRAY)
img_lq = np.tile(img_lq[:, :, None], [1, 1, 3])
if self.opt.get('gt_gray'):
img_gt = cv2.cvtColor(img_gt, cv2.COLOR_BGR2GRAY)
img_gt = np.tile(img_gt[:, :, None], [1, 1, 3])
img_gt, img_lq = img2tensor([img_gt, img_lq],
bgr2rgb=True,
float32=True)
if self.color_jitter_pt_prob is not None and np.random.uniform(
) < self.color_jitter_pt_prob:
brightness = self.opt.get('brightness', (0.5, 1.5))
contrast = self.opt.get('contrast', (0.5, 1.5))
saturation = self.opt.get('saturation', (0, 1.5))
hue = self.opt.get('hue', (-0.1, 0.1))
img_lq = self.color_jitter_pt(img_lq, brightness, contrast,
saturation, hue)
img_lq = np.clip((img_lq * 255.0).round(), 0, 255) / 255.0
img_gt = normalize(img_gt, self.mean, self.std)
img_lq = normalize(img_lq, self.mean, self.std)
if self.crop_components:
return_dict = {
'lq': img_lq,
'gt': img_gt,
'gt_path': gt_path,
'loc_left_eye': loc_left_eye,
'loc_right_eye': loc_right_eye,
'loc_mouth': loc_mouth
}
return return_dict
else:
return {'lq': img_lq, 'gt': img_gt, 'gt_path': gt_path}
def __len__(self):
return len(self.paths)
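# Illustrative smoke test (not part of the original file): a minimal sketch of how
# this dataset might be instantiated directly. It assumes a folder of high-quality
# face images at data/gfpgan_data/train, matching configs/gfpgan_ffhq1024.yaml.
if __name__ == '__main__':
    dataset = FFHQDegradationDataset(dataroot_gt='data/gfpgan_data/train',
                                     io_backend={'type': 'disk'},
                                     mean=[0.5, 0.5, 0.5],
                                     std=[0.5, 0.5, 0.5],
                                     out_size=512,
                                     use_hflip=True,
                                     blur_kernel_size=41,
                                     kernel_list=['iso', 'aniso'],
                                     kernel_prob=[0.5, 0.5],
                                     blur_sigma=[0.1, 10],
                                     downsample_range=[0.8, 8],
                                     noise_range=[0, 20],
                                     jpeg_range=[60, 100])
    sample = dataset[0]  # the LQ input is synthesized on-the-fly from the GT image
    print(sample['gt_path'], sample['gt'].shape, sample['lq'].shape)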
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle.vision.transforms.functional import normalize
from .builder import DATASETS
from ppgan.utils.gfpgan_tools import *
@DATASETS.register()
class PairedImageDataset(paddle.io.Dataset):
"""Paired image dataset for image restoration.
Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and GT image pairs.
There are three modes:
1. 'lmdb': Use lmdb files.
If opt['io_backend'] == lmdb.
2. 'meta_info_file': Use meta information file to generate paths.
If opt['io_backend'] != lmdb and opt['meta_info_file'] is not None.
3. 'folder': Scan folders to generate paths.
The rest.
Args:
opt (dict): Config for train datasets. It contains the following keys:
dataroot_gt (str): Data root path for gt.
dataroot_lq (str): Data root path for lq.
meta_info_file (str): Path for meta information file.
io_backend (dict): IO backend type and other kwarg.
filename_tmpl (str): Template for each filename. Note that the template excludes the file extension.
Default: '{}'.
gt_size (int): Cropped patched size for gt patches.
use_hflip (bool): Use horizontal flips.
use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
scale (bool): Scale, which will be added automatically.
phase (str): 'train' or 'val'.
"""
def __init__(self, **opt):
super(PairedImageDataset, self).__init__()
self.opt = opt
# file client (io backend)
self.file_client = None
self.io_backend_opt = opt['io_backend']
self.mean = opt['mean'] if 'mean' in opt else None
self.std = opt['std'] if 'std' in opt else None
self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq']
if 'filename_tmpl' in opt:
self.filename_tmpl = opt['filename_tmpl']
else:
self.filename_tmpl = '{}'
if self.io_backend_opt['type'] == 'lmdb':
self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder]
self.io_backend_opt['client_keys'] = ['lq', 'gt']
self.paths = paired_paths_from_lmdb(
[self.lq_folder, self.gt_folder], ['lq', 'gt'])
elif 'meta_info_file' in self.opt and self.opt[
'meta_info_file'] is not None:
self.paths = paired_paths_from_meta_info_file(
[self.lq_folder, self.gt_folder], ['lq', 'gt'],
self.opt['meta_info_file'], self.filename_tmpl)
else:
self.paths = paired_paths_from_folder(
[self.lq_folder, self.gt_folder], ['lq', 'gt'],
self.filename_tmpl)
def __getitem__(self, index):
if self.file_client is None:
self.file_client = FileClient(self.io_backend_opt.pop('type'),
**self.io_backend_opt)
# print(self.file_client)
scale = self.opt['scale']
# Load gt and lq images. Dimension order: HWC; channel order: BGR;
# image range: [0, 1], float32.
gt_path = self.paths[index]['gt_path']
img_bytes = self.file_client.get(gt_path, 'gt')
img_gt = imfrombytes(img_bytes, float32=True)
lq_path = self.paths[index]['lq_path']
img_bytes = self.file_client.get(lq_path, 'lq')
img_lq = imfrombytes(img_bytes, float32=True)
# augmentation for training
if self.opt['phase'] == 'train':
gt_size = self.opt['gt_size']
# random crop
img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale,
gt_path)
# flip, rotation
img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'],
self.opt['use_rot'])
# color space transform
if 'color' in self.opt and self.opt['color'] == 'y':
img_gt = bgr2ycbcr(img_gt, y_only=True)[..., None]
img_lq = bgr2ycbcr(img_lq, y_only=True)[..., None]
# crop the unmatched GT images during validation or testing, especially for SR benchmark datasets
# TODO: It is better to update the datasets, rather than force to crop
if self.opt['phase'] != 'train':
img_gt = img_gt[0:img_lq.shape[0] * scale,
0:img_lq.shape[1] * scale, :]
# BGR to RGB, HWC to CHW, numpy to tensor
img_gt, img_lq = img2tensor([img_gt, img_lq],
bgr2rgb=True,
float32=True)
# normalize
if self.mean is not None or self.std is not None:
img_lq = normalize(img_lq, self.mean, self.std)
img_gt = normalize(img_gt, self.mean, self.std)
return {
'lq': img_lq,
'gt': img_gt,
'lq_path': lq_path,
'gt_path': gt_path
}
def __len__(self):
return len(self.paths)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import cv2
import numpy as np
import sys
import paddle
import paddle.nn as nn
from ppgan.utils.visual import *
from ppgan.utils.download import get_path_from_url
from ppgan.models.generators import GFPGANv1Clean
from ppgan.models.generators import GFPGANv1
from ppgan.faceutils.face_detection.detection.blazeface.utils import *
GFPGAN_weights = 'https://paddlegan.bj.bcebos.com/models/GFPGAN.pdparams'
class gfp_FaceEnhancement(object):
def __init__(self, size=512, batch_size=1):
super(gfp_FaceEnhancement, self).__init__()
# Initialise the face detector
model_weights_path = get_path_from_url(GFPGAN_weights)
model_weights = paddle.load(model_weights_path)
self.face_enhance = GFPGANv1(out_size=512,
num_style_feat=512,
channel_multiplier=1,
resample_kernel=[1, 3, 3, 1],
decoder_load_path=None,
fix_decoder=True,
num_mlp=8,
lr_mlp=0.01,
input_is_latent=True,
different_w=True,
narrow=1,
sft_half=True)
self.face_enhance.load_dict(model_weights['net_g_ema'])
self.face_enhance.eval()
self.size = size
self.mask = np.zeros((512, 512), np.float32)
cv2.rectangle(self.mask, (26, 26), (486, 486), (1, 1, 1), -1,
cv2.LINE_AA)
self.mask = cv2.GaussianBlur(self.mask, (101, 101), 11)
self.mask = cv2.GaussianBlur(self.mask, (101, 101), 11)
self.mask = paddle.tile(paddle.to_tensor(
self.mask).unsqueeze(0).unsqueeze(-1),
repeat_times=[batch_size, 1, 1, 3]).numpy()
def enhance_from_image(self, img):
if isinstance(img, np.ndarray):
img, _ = resize_and_crop_image(img, 512)
img = paddle.to_tensor(img).transpose([2, 0, 1])
else:
assert img.shape == [3, 512, 512]
return self.enhance_from_batch(img.unsqueeze(0))[0]
def enhance_from_batch(self, img):
if isinstance(img, np.ndarray):
img_ori, _ = resize_and_crop_batch(img, 512)
img = paddle.to_tensor(img_ori).transpose([0, 3, 1, 2])
else:
assert img.shape[1:] == [3, 512, 512]
img_ori = img.transpose([0, 2, 3, 1]).numpy()
img_t = (img / 255. - 0.5) / 0.5
with paddle.no_grad():
out, __ = self.face_enhance(img_t)
image_tensor = out * 0.5 + 0.5
image_tensor = image_tensor.transpose([0, 2, 3, 1]) # RGB
image_numpy = paddle.clip(image_tensor, 0, 1) * 255.0
out = image_numpy.astype(np.uint8).cpu().numpy()
return out * self.mask + (1 - self.mask) * img_ori
@@ -39,4 +39,5 @@ from .rcan_model import RCANModel
from .prenet_model import PReNetModel
from .gpen_model import GPENModel
from .swinir_model import SwinIRModel
from .gfpgan_model import GFPGANModel
from .invdn_model import InvDNModel
@@ -10,3 +10,4 @@ from .builder import build_criterion
from .ssim import SSIM
from .id_loss import IDLoss
from .gfpgan_loss import GFPGANGANLoss, GFPGANL1Loss, GFPGANPerceptualLoss
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import cv2
import math
import numpy as np
from collections import OrderedDict
import os
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.vision.models import vgg
from .builder import CRITERIONS
from ppgan.utils.download import get_path_from_url
VGG_PRETRAIN_PATH = os.path.join(os.getcwd(), 'pretrain', 'vgg19' + '.pdparams')
NAMES = {
'vgg11': [
'conv1_1', 'relu1_1', 'pool1', 'conv2_1', 'relu2_1', 'pool2', 'conv3_1',
'relu3_1', 'conv3_2', 'relu3_2', 'pool3', 'conv4_1', 'relu4_1',
'conv4_2', 'relu4_2', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2',
'relu5_2', 'pool5'
],
'vgg13': [
'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1',
'relu2_1', 'conv2_2', 'relu2_2', 'pool2', 'conv3_1', 'relu3_1',
'conv3_2', 'relu3_2', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2',
'relu4_2', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'pool5'
],
'vgg16': [
'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1',
'relu2_1', 'conv2_2', 'relu2_2', 'pool2', 'conv3_1', 'relu3_1',
'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'pool3', 'conv4_1',
'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'pool4',
'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3',
'pool5'
],
'vgg19': [
'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1',
'relu2_1', 'conv2_2', 'relu2_2', 'pool2', 'conv3_1', 'relu3_1',
'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4',
'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3',
'relu4_3', 'conv4_4', 'relu4_4', 'pool4', 'conv5_1', 'relu5_1',
'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4',
'pool5'
]
}
def insert_bn(names):
"""Insert bn layer after each conv.
Args:
names (list): The list of layer names.
Returns:
list: The list of layer names with bn layers.
"""
names_bn = []
for name in names:
names_bn.append(name)
if 'conv' in name:
position = name.replace('conv', '')
names_bn.append('bn' + position)
return names_bn
class VGGFeatureExtractor(nn.Layer):
"""VGG network for feature extraction.
In this implementation, we allow users to choose whether use normalization
in the input feature and the type of vgg network. Note that the pretrained
path must fit the vgg type.
Args:
layer_name_list (list[str]): Forward function returns the corresponding
features according to the layer_name_list.
Example: {'relu1_1', 'relu2_1', 'relu3_1'}.
vgg_type (str): Set the type of vgg network. Default: 'vgg19'.
use_input_norm (bool): If True, normalize the input image. Importantly,
the input feature must in the range [0, 1]. Default: True.
range_norm (bool): If True, norm images with range [-1, 1] to [0, 1].
Default: False.
requires_grad (bool): If true, the parameters of VGG network will be
optimized. Default: False.
remove_pooling (bool): If true, the max pooling operations in VGG net
will be removed. Default: False.
pooling_stride (int): The stride of max pooling operation. Default: 2.
"""
def __init__(
self,
layer_name_list,
vgg_type='vgg19',
use_input_norm=True,
range_norm=False,
requires_grad=False,
remove_pooling=False,
pooling_stride=2,
pretrained_url='https://paddlegan.bj.bcebos.com/models/vgg19.pdparams'
):
super(VGGFeatureExtractor, self).__init__()
self.layer_name_list = layer_name_list
self.use_input_norm = use_input_norm
self.range_norm = range_norm
self.names = NAMES[vgg_type.replace('_bn', '')]
if 'bn' in vgg_type:
self.names = insert_bn(self.names)
max_idx = 0
for v in layer_name_list:
idx = self.names.index(v)
if idx > max_idx:
max_idx = idx
if os.path.exists(VGG_PRETRAIN_PATH):
vgg_net = getattr(vgg, vgg_type)(pretrained=False)
weight_path = get_path_from_url(pretrained_url)
state_dict = paddle.load(weight_path)
vgg_net.set_state_dict(state_dict)
else:
vgg_net = getattr(vgg, vgg_type)(pretrained=True)
features = vgg_net.features[:max_idx + 1]
self.vgg_layers = nn.Sequential()
for k, v in zip(self.names, features):
if 'pool' in k:
if remove_pooling:
continue
else:
self.vgg_layers.add_sublayer(
k, nn.MaxPool2D(kernel_size=2, stride=pooling_stride))
else:
self.vgg_layers.add_sublayer(k, v)
if not requires_grad:
self.vgg_layers.eval()
for param in self.parameters():
param.stop_gradient = True
else:
self.vgg_layers.train()
for param in self.parameters():
param.stop_gradient = False
if self.use_input_norm:
self.register_buffer(
'mean',
paddle.to_tensor([0.485, 0.456, 0.406]).reshape([1, 3, 1, 1]))
self.register_buffer(
'std',
paddle.to_tensor([0.229, 0.224, 0.225]).reshape([1, 3, 1, 1]))
def forward(self, x, rep=None):
"""Forward function.
Args:
x (Tensor): Input tensor with shape (n, c, h, w).
Returns:
Tensor: Forward results.
"""
if self.range_norm:
x = (x + 1) / 2
if self.use_input_norm:
x = (x - self.mean) / self.std
output = {}
for name, module in self.vgg_layers.named_children():
x = module(x)
if name in self.layer_name_list:
output[name] = x.clone()
return output
@CRITERIONS.register()
class GFPGANPerceptualLoss(nn.Layer):
"""Perceptual loss with commonly used style loss.
Args:
layer_weights (dict): The weight for each layer of vgg feature.
Here is an example: {'conv5_4': 1.}, which means the conv5_4
feature layer (before relu5_4) will be extracted with weight
1.0 in calculating losses.
vgg_type (str): The type of vgg network used as feature extractor.
Default: 'vgg19'.
use_input_norm (bool): If True, normalize the input image in vgg.
Default: True.
range_norm (bool): If True, norm images with range [-1, 1] to [0, 1].
Default: False.
perceptual_weight (float): If `perceptual_weight > 0`, the perceptual
loss will be calculated and the loss will multiplied by the
weight. Default: 1.0.
style_weight (float): If `style_weight > 0`, the style loss will be
calculated and the loss will multiplied by the weight.
Default: 0.
criterion (str): Criterion used for perceptual loss. Default: 'l1'.
"""
def __init__(self,
layer_weights,
vgg_type='vgg19',
use_input_norm=True,
range_norm=False,
perceptual_weight=1.0,
style_weight=0.0,
criterion='l1'):
super(GFPGANPerceptualLoss, self).__init__()
self.perceptual_weight = perceptual_weight
self.style_weight = style_weight
self.layer_weights = layer_weights
self.vgg = VGGFeatureExtractor(layer_name_list=list(
layer_weights.keys()),
vgg_type=vgg_type,
use_input_norm=use_input_norm,
range_norm=range_norm)
self.criterion_type = criterion
if self.criterion_type == 'l1':
self.criterion = paddle.nn.L1Loss()
elif self.criterion_type == 'fro':
self.criterion = None
else:
raise NotImplementedError(
f'{criterion} criterion has not been supported.')
def forward(self, x, gt, rep=None):
"""Forward function.
Args:
x (Tensor): Input tensor with shape (n, c, h, w).
gt (Tensor): Ground-truth tensor with shape (n, c, h, w).
Returns:
Tensor: Forward results.
"""
x_features = self.vgg(x, rep)
gt_features = self.vgg(gt.detach())
if self.perceptual_weight > 0:
percep_loss = 0
for k in x_features.keys():
if self.criterion_type == 'fro':
percep_loss += paddle.linalg.norm(
x_features[k] - gt_features[k],
p='fro') * self.layer_weights[k]
else:
percep_loss += self.criterion(
x_features[k], gt_features[k]) * self.layer_weights[k]
percep_loss *= self.perceptual_weight
else:
percep_loss = None
if self.style_weight > 0:
style_loss = 0
for k in x_features.keys():
if self.criterion_type == 'fro':
style_loss += paddle.linalg.norm(
self._gram_mat(x_features[k]) -
self._gram_mat(gt_features[k]),
p='fro') * self.layer_weights[k]
else:
style_loss += self.criterion(
self._gram_mat(x_features[k]),
self._gram_mat(gt_features[k])) * self.layer_weights[k]
style_loss *= self.style_weight
else:
style_loss = None
return percep_loss, style_loss
def _gram_mat(self, x):
"""Calculate Gram matrix.
Args:
x (paddle.Tensor): Tensor with shape of (n, c, h, w).
Returns:
paddle.Tensor: Gram matrix.
"""
(n, c, h, w) = x.shape
features = x.reshape([n, c, w * h])
features_t = features.transpose([0, 2, 1])
gram = features.bmm(features_t) / (c * h * w)
return gram
@CRITERIONS.register()
class GFPGANGANLoss(nn.Layer):
"""Define GAN loss.
Args:
gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'hinge'.
real_label_val (float): The value for real label. Default: 1.0.
fake_label_val (float): The value for fake label. Default: 0.0.
loss_weight (float): Loss weight. Default: 1.0.
Note that loss_weight is only for generators; and it is always 1.0
for discriminators.
"""
def __init__(self,
gan_type,
real_label_val=1.0,
fake_label_val=0.0,
loss_weight=1.0):
super(GFPGANGANLoss, self).__init__()
self.gan_type = gan_type
self.loss_weight = loss_weight
self.real_label_val = real_label_val
self.fake_label_val = fake_label_val
if self.gan_type == 'vanilla':
self.loss = nn.BCEWithLogitsLoss()
elif self.gan_type == 'lsgan':
self.loss = nn.MSELoss()
elif self.gan_type == 'wgan':
self.loss = self._wgan_loss
elif self.gan_type == 'wgan_softplus':
self.loss = self._wgan_softplus_loss
elif self.gan_type == 'hinge':
self.loss = nn.ReLU()
else:
raise NotImplementedError(
f'GAN type {self.gan_type} is not implemented.')
def _wgan_loss(self, input, target):
"""wgan loss.
Args:
input (Tensor): Input tensor.
target (bool): Target label.
Returns:
Tensor: wgan loss.
"""
return -input.mean() if target else input.mean()
def _wgan_softplus_loss(self, input, target):
"""wgan loss with soft plus. softplus is a smooth approximation to the
ReLU function.
In StyleGAN2, it is called:
Logistic loss for discriminator;
Non-saturating loss for generator.
Args:
input (Tensor): Input tensor.
target (bool): Target label.
Returns:
Tensor: wgan loss.
"""
return F.softplus(-1.0 *
input).mean() if target else F.softplus(input).mean()
def get_target_label(self, input, target_is_real):
"""Get target label.
Args:
input (Tensor): Input tensor.
target_is_real (bool): Whether the target is real or fake.
Returns:
(bool | Tensor): Target tensor. Return bool for wgan, otherwise,
return Tensor.
"""
if self.gan_type in ['wgan', 'wgan_softplus']:
return target_is_real
target_val = (self.real_label_val
if target_is_real else self.fake_label_val)
return paddle.ones(input.shape, dtype=input.dtype) * target_val
def forward(self, input, target_is_real, is_disc=False):
"""
Args:
input (Tensor): The input for the loss module, i.e., the network
prediction.
target_is_real (bool): Whether the target is real or fake.
is_disc (bool): Whether the loss for discriminators or not.
Default: False.
Returns:
Tensor: GAN loss value.
"""
target_label = self.get_target_label(input, target_is_real)
if self.gan_type == 'hinge':
if is_disc: # for discriminators in hinge-gan
input = -input if target_is_real else input
loss = self.loss(1 + input).mean()
else: # for generators in hinge-gan
loss = -input.mean()
else: # other gan types
loss = self.loss(input, target_label)
# loss_weight is always 1.0 for discriminators
return loss if is_disc else loss * self.loss_weight
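# Illustrative usage (not part of the original file): with gan_type='wgan_softplus',
# the generator uses the non-saturating loss and the discriminator the logistic loss:
#   gan_loss = GFPGANGANLoss(gan_type='wgan_softplus', loss_weight=0.1)
#   fake_pred = paddle.randn([4, 1])
#   g_loss = gan_loss(fake_pred, target_is_real=True, is_disc=False)       # scaled by loss_weight
#   d_loss_fake = gan_loss(fake_pred, target_is_real=False, is_disc=True)  # weight fixed at 1.0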
@CRITERIONS.register()
class GFPGANL1Loss(nn.Layer):
"""L1 (mean absolute error, MAE) loss.
Args:
loss_weight (float): Loss weight for L1 loss. Default: 1.0.
reduction (str): Specifies the reduction to apply to the output.
Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
"""
def __init__(self, loss_weight=1.0, reduction='mean'):
super(GFPGANL1Loss, self).__init__()
if reduction not in ['none', 'mean', 'sum']:
raise ValueError(
f'Unsupported reduction mode: {reduction}. Supported ones are: "none" | "mean" | "sum"'
)
self.loss_weight = loss_weight
self.l1_loss = paddle.nn.L1Loss(reduction)
def forward(self, pred, target):
"""
Args:
pred (Tensor): of shape (N, C, H, W). Predicted tensor.
target (Tensor): of shape (N, C, H, W). Ground truth tensor.
weight (Tensor, optional): of shape (N, C, H, W). Element-wise weights. Default: None.
"""
return self.loss_weight * self.l1_loss(pred, target)
@@ -25,3 +25,4 @@ from .discriminator_firstorder import FirstOrderDiscriminator
from .discriminator_lapstyle import LapStyleDiscriminator
from .discriminator_photopen import MultiscaleDiscriminator
from .discriminator_singan import SinGANDiscriminator
from .arcface_arch_paddle import ResNetArcFace
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from .builder import DISCRIMINATORS
def conv3x3(inplanes, outplanes, stride=1):
"""A simple wrapper for 3x3 convolution with padding.
Args:
inplanes (int): Channel number of inputs.
outplanes (int): Channel number of outputs.
stride (int): Stride in convolution. Default: 1.
"""
return nn.Conv2D(inplanes,
outplanes,
kernel_size=3,
stride=stride,
padding=1,
bias_attr=False)
class BasicBlock(nn.Layer):
"""Basic residual block used in the ResNetArcFace architecture.
Args:
inplanes (int): Channel number of inputs.
planes (int): Channel number of outputs.
stride (int): Stride in convolution. Default: 1.
downsample (nn.Module): The downsample module. Default: None.
"""
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(BasicBlock, self).__init__()
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = nn.BatchNorm2D(planes)
self.relu = nn.ReLU()
self.conv2 = conv3x3(planes, planes)
self.bn2 = nn.BatchNorm2D(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class IRBlock(nn.Layer):
"""Improved residual block (IR Block) used in the ResNetArcFace architecture.
Args:
inplanes (int): Channel number of inputs.
planes (int): Channel number of outputs.
stride (int): Stride in convolution. Default: 1.
downsample (nn.Module): The downsample module. Default: None.
use_se (bool): Whether use the SEBlock (squeeze and excitation block). Default: True.
"""
expansion = 1
def __init__(self,
inplanes,
planes,
stride=1,
downsample=None,
use_se=True):
super(IRBlock, self).__init__()
self.bn0 = nn.BatchNorm2D(inplanes)
self.conv1 = conv3x3(inplanes, inplanes)
self.bn1 = nn.BatchNorm2D(inplanes)
self.prelu = PReLU_layer()
self.conv2 = conv3x3(inplanes, planes, stride)
self.bn2 = nn.BatchNorm2D(planes)
self.downsample = downsample
self.stride = stride
self.use_se = use_se
if self.use_se:
self.se = SEBlock(planes)
def forward(self, x):
residual = x
out = self.bn0(x)
out = self.conv1(out)
out = self.bn1(out)
out = self.prelu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.use_se:
out = self.se(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.prelu(out)
return out
class Bottleneck(nn.Layer):
"""Bottleneck block used in the ResNetArcFace architecture.
Args:
inplanes (int): Channel number of inputs.
planes (int): Channel number of outputs.
stride (int): Stride in convolution. Default: 1.
downsample (nn.Module): The downsample module. Default: None.
"""
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(Bottleneck, self).__init__()
self.conv1 = nn.Conv2D(inplanes, planes, kernel_size=1, bias_attr=False)
self.bn1 = nn.BatchNorm2D(planes)
self.conv2 = nn.Conv2D(planes,
planes,
kernel_size=3,
stride=stride,
padding=1,
bias_attr=False)
self.bn2 = nn.BatchNorm2D(planes)
self.conv3 = nn.Conv2D(planes,
planes * self.expansion,
kernel_size=1,
bias_attr=False)
self.bn3 = nn.BatchNorm2D(planes * self.expansion)
self.relu = nn.ReLU()
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class PReLU_layer(nn.Layer):
def __init__(self, init_value=0.25, num=1):
super(PReLU_layer, self).__init__()
x = self.create_parameter(
attr=None,
shape=[num],
dtype=paddle.get_default_dtype(),
is_bias=False,
default_initializer=nn.initializer.Constant(init_value))
self.add_parameter('weight', x)
def forward(self, x):
return F.prelu(x, self.weight)
class SEBlock(nn.Layer):
"""The squeeze-and-excitation block (SEBlock) used in the IRBlock.
Args:
channel (int): Channel number of inputs.
reduction (int): Channel reduction ration. Default: 16.
"""
def __init__(self, channel, reduction=16):
super(SEBlock, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2D(1)
self.fc = nn.Sequential(nn.Linear(channel, channel // reduction),
nn.PReLU(),
nn.Linear(channel // reduction, channel),
nn.Sigmoid())
def forward(self, x):
b, c, _, _ = x.shape
y = self.avg_pool(x).reshape([b, c])
y = self.fc(y).reshape([b, c, 1, 1])
return x * y
def constant_init(param, **kwargs):
initializer = nn.initializer.Constant(**kwargs)
initializer(param, param.block)
@DISCRIMINATORS.register()
class ResNetArcFace(nn.Layer):
"""ArcFace with ResNet architectures.
Ref: ArcFace: Additive Angular Margin Loss for Deep Face Recognition.
Args:
block (str): Block used in the ArcFace architecture.
layers (tuple(int)): Block numbers in each layer.
use_se (bool): Whether use the SEBlock (squeeze and excitation block). Default: True.
"""
def __init__(self, block, layers, use_se=True, reprod_logger=None):
if block == 'IRBlock':
block = IRBlock
self.inplanes = 64
self.use_se = use_se
super(ResNetArcFace, self).__init__()
self.conv1 = nn.Conv2D(1, 64, kernel_size=3, padding=1, bias_attr=False)
self.bn1 = nn.BatchNorm2D(64)
self.maxpool = nn.MaxPool2D(kernel_size=2, stride=2)
self.prelu = PReLU_layer()
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
self.bn4 = nn.BatchNorm2D(512)
self.dropout = nn.Dropout()
self.fc5 = nn.Linear(512 * 8 * 8, 512)
self.bn5 = nn.BatchNorm1D(512)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, paddle.nn.Conv2D):
initializer = nn.initializer.XavierNormal()
initializer(m.weight, m.weight.block)
elif isinstance(m, paddle.nn.BatchNorm2D) or isinstance(
m, paddle.nn.BatchNorm1D):
constant_init(m.weight, value=1.)
constant_init(m.bias, value=0.)
elif isinstance(m, paddle.nn.Linear):
initializer = nn.initializer.XavierNormal()
initializer(m.weight, m.weight.block)
constant_init(m.bias, value=0.)
def _make_layer(self, block, planes, num_blocks, stride=1):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2D(self.inplanes,
planes * block.expansion,
kernel_size=1,
stride=stride,
bias_attr=False),
nn.BatchNorm2D(planes * block.expansion))
layers = []
layers.append(
block(self.inplanes, planes, stride, downsample,
use_se=self.use_se))
self.inplanes = planes
for _ in range(1, num_blocks):
layers.append(block(self.inplanes, planes, use_se=self.use_se))
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.prelu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.bn4(x)
x = self.dropout(x)
x = x.reshape([x.shape[0], -1])
x = self.fc5(x)
x = self.bn5(x)
return x
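# Illustrative usage (not part of the original file): the identity network configured
# in gfpgan_ffhq1024.yaml (block='IRBlock', layers=[2, 2, 2, 2], use_se=False) maps
# 1-channel 128x128 gray face crops to 512-d embeddings used by the identity loss:
#   net = ResNetArcFace(block='IRBlock', layers=[2, 2, 2, 2], use_se=False)
#   feat = net(paddle.randn([4, 1, 128, 128]))  # -> shape [4, 512]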
@@ -43,4 +43,6 @@ from .rcan import RCAN
from .prenet import PReNet
from .gpen import GPEN
from .swinir import SwinIR
from .gfpganv1_clean_arch import GFPGANv1Clean
from .gfpganv1_arch import GFPGANv1, StyleGAN2DiscriminatorGFPGAN
from .invdn import InvDN
(This file's diff is collapsed.)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import random
import paddle
from paddle import nn
from paddle.nn import functional as F
from ppgan.models.generators.stylegan2_clean_arch import StyleGAN2GeneratorClean
from ppgan.models.generators.builder import GENERATORS
class StyleGAN2GeneratorCSFT(StyleGAN2GeneratorClean):
"""StyleGAN2 Generator with SFT modulation (Spatial Feature Transform).
It is the clean version without custom compiled CUDA extensions used in StyleGAN2.
Args:
out_size (int): The spatial size of outputs.
num_style_feat (int): Channel number of style features. Default: 512.
num_mlp (int): Layer number of MLP style layers. Default: 8.
channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
narrow (float): The narrow ratio for channels. Default: 1.
sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
"""
def __init__(self,
out_size,
num_style_feat=512,
num_mlp=8,
channel_multiplier=2,
narrow=1,
sft_half=False):
super(StyleGAN2GeneratorCSFT,
self).__init__(out_size,
num_style_feat=num_style_feat,
num_mlp=num_mlp,
channel_multiplier=channel_multiplier,
narrow=narrow)
self.sft_half = sft_half
def forward(self,
styles,
conditions,
input_is_latent=False,
noise=None,
randomize_noise=True,
truncation=1,
truncation_latent=None,
inject_index=None,
return_latents=False):
"""Forward function for StyleGAN2GeneratorCSFT.
Args:
styles (list[Tensor]): Sample codes of styles.
conditions (list[Tensor]): SFT conditions to generators.
input_is_latent (bool): Whether input is latent style. Default: False.
noise (Tensor | None): Input noise or None. Default: None.
randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
truncation (float): The truncation ratio. Default: 1.
truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
inject_index (int | None): The injection index for mixing noise. Default: None.
return_latents (bool): Whether to return style latents. Default: False.
"""
if not input_is_latent:
styles = [self.style_mlp(s) for s in styles]
if noise is None:
if randomize_noise:
noise = [None] * self.num_layers
else:
noise = [
getattr(self.noises, f'noise{i}')
for i in range(self.num_layers)
]
if truncation < 1:
style_truncation = []
for style in styles:
style_truncation.append(truncation_latent + truncation *
(style - truncation_latent))
styles = style_truncation
if len(styles) == 1:
inject_index = self.num_latent
if styles[0].ndim < 3:
latent = paddle.tile(styles[0].unsqueeze(1),
repeat_times=[1, inject_index, 1])
else:
latent = styles[0]
elif len(styles) == 2:
if inject_index is None:
inject_index = random.randint(1, self.num_latent - 1)
latent1 = paddle.tile(styles[0].unsqueeze(1),
repeat_times=[1, inject_index, 1])
latent2 = paddle.tile(
styles[1].unsqueeze(1),
repeat_times=[1, self.num_latent - inject_index, 1])
latent = paddle.concat([latent1, latent2], axis=1)
out = self.constant_input(latent.shape[0])
out = self.style_conv1(out, latent[:, 0], noise=noise[0])
skip = self.to_rgb1(out, latent[:, 1])
i = 1
for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2],
self.style_convs[1::2],
noise[1::2],
noise[2::2],
self.to_rgbs):
out = conv1(out, latent[:, i], noise=noise1)
if i < len(conditions):
if self.sft_half:
out_same, out_sft = paddle.split(out, 2, axis=1)
out_sft = out_sft * conditions[i - 1] + conditions[i]
out = paddle.concat([out_same, out_sft], axis=1)
else:
out = out * conditions[i - 1] + conditions[i]
out = conv2(out, latent[:, i + 1], noise=noise2)
skip = to_rgb(out, latent[:, i + 2], skip)
i += 2
image = skip
if return_latents:
return image, latent
else:
return image, None
class ResBlock(nn.Layer):
"""Residual block with bilinear upsampling/downsampling.
Args:
in_channels (int): Channel number of the input.
out_channels (int): Channel number of the output.
mode (str): Upsampling/downsampling mode. Options: down | up. Default: down.
"""
def __init__(self, in_channels, out_channels, mode='down'):
super(ResBlock, self).__init__()
self.conv1 = nn.Conv2D(in_channels, in_channels, 3, 1, 1)
self.conv2 = nn.Conv2D(in_channels, out_channels, 3, 1, 1)
self.skip = nn.Conv2D(in_channels, out_channels, 1, bias_attr=False)
if mode == 'down':
self.scale_factor = 0.5
elif mode == 'up':
self.scale_factor = 2
def forward(self, x):
out = paddle.nn.functional.leaky_relu(self.conv1(x), negative_slope=0.2)
out = F.interpolate(out, scale_factor=self.scale_factor, mode=\
'bilinear', align_corners=False)
out = paddle.nn.functional.leaky_relu(self.conv2(out),
negative_slope=0.2)
x = F.interpolate(x, scale_factor=self.scale_factor, mode=\
'bilinear', align_corners=False)
skip = self.skip(x)
out = out + skip
return out
def debug(x):
print(type(x))
if isinstance(x, list):
for i, v in enumerate(x):
print(i, v.shape)
else:
print(0, x.shape)
@GENERATORS.register()
class GFPGANv1Clean(nn.Layer):
"""The GFPGAN architecture: Unet + StyleGAN2 decoder with SFT.
It is the clean version without custom compiled CUDA extensions used in StyleGAN2.
Ref: GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior.
Args:
out_size (int): The spatial size of outputs.
num_style_feat (int): Channel number of style features. Default: 512.
        channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 1.
decoder_load_path (str): The path to the pre-trained decoder model (usually, the StyleGAN2). Default: None.
fix_decoder (bool): Whether to fix the decoder. Default: True.
num_mlp (int): Layer number of MLP style layers. Default: 8.
input_is_latent (bool): Whether input is latent style. Default: False.
different_w (bool): Whether to use different latent w for different layers. Default: False.
narrow (float): The narrow ratio for channels. Default: 1.
sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
"""
def __init__(self,
out_size,
num_style_feat=512,
channel_multiplier=1,
decoder_load_path=None,
fix_decoder=True,
num_mlp=8,
input_is_latent=False,
different_w=False,
narrow=1,
sft_half=False):
super(GFPGANv1Clean, self).__init__()
self.input_is_latent = input_is_latent
self.different_w = different_w
self.num_style_feat = num_style_feat
unet_narrow = narrow * 0.5
print("unet_narrow", unet_narrow, "channel_multiplier",
channel_multiplier)
channels = {
'4': int(512 * unet_narrow),
'8': int(512 * unet_narrow),
'16': int(512 * unet_narrow),
'32': int(512 * unet_narrow),
'64': int(256 * channel_multiplier * unet_narrow),
'128': int(128 * channel_multiplier * unet_narrow),
'256': int(64 * channel_multiplier * unet_narrow),
'512': int(32 * channel_multiplier * unet_narrow),
'1024': int(16 * channel_multiplier * unet_narrow)
}
self.log_size = int(math.log(out_size, 2))
first_out_size = 2**int(math.log(out_size, 2))
self.conv_body_first = nn.Conv2D(3, channels[f'{first_out_size}'], 1)
in_channels = channels[f'{first_out_size}']
self.conv_body_down = nn.LayerList()
for i in range(self.log_size, 2, -1):
out_channels = channels[f'{2 ** (i - 1)}']
self.conv_body_down.append(
ResBlock(in_channels, out_channels, mode='down'))
in_channels = out_channels
self.final_conv = nn.Conv2D(in_channels, channels['4'], 3, 1, 1)
in_channels = channels['4']
self.conv_body_up = nn.LayerList()
for i in range(3, self.log_size + 1):
out_channels = channels[f'{2 ** i}']
self.conv_body_up.append(
ResBlock(in_channels, out_channels, mode='up'))
in_channels = out_channels
self.toRGB = nn.LayerList()
for i in range(3, self.log_size + 1):
self.toRGB.append(nn.Conv2D(channels[f'{2 ** i}'], 3, 1))
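        # different_w: predict one w vector per decoder latent
        # (log2(out_size) * 2 - 2 of them) instead of a single shared w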
if different_w:
linear_out_channel = (int(math.log(out_size, 2)) * 2 -
2) * num_style_feat
else:
linear_out_channel = num_style_feat
self.final_linear = nn.Linear(channels['4'] * 4 * 4, linear_out_channel)
        self.stylegan_decoder = StyleGAN2GeneratorCSFT(
            out_size=out_size,
            num_style_feat=num_style_feat,
            num_mlp=num_mlp,
            channel_multiplier=channel_multiplier,
            narrow=narrow,
            sft_half=sft_half)
if decoder_load_path:
self.stylegan_decoder.load_state_dict(
paddle.load(decoder_load_path)['params_ema'])
if fix_decoder:
for _, param in self.stylegan_decoder.named_parameters():
param.requires_grad = False
self.condition_scale = nn.LayerList()
self.condition_shift = nn.LayerList()
for i in range(3, self.log_size + 1):
out_channels = channels[f'{2 ** i}']
if sft_half:
sft_out_channels = out_channels
else:
sft_out_channels = out_channels * 2
self.condition_scale.append(
nn.Sequential(
nn.Conv2D(out_channels, out_channels, 3, 1, 1),
                    nn.LeakyReLU(0.2),
nn.Conv2D(out_channels, sft_out_channels, 3, 1, 1)))
self.condition_shift.append(
nn.Sequential(
nn.Conv2D(out_channels, out_channels, 3, 1, 1),
                    nn.LeakyReLU(0.2),
nn.Conv2D(out_channels, sft_out_channels, 3, 1, 1)))
def forward(self,
x,
return_latents=False,
return_rgb=True,
randomize_noise=True):
"""Forward function for GFPGANv1Clean.
Args:
x (Tensor): Input images.
return_latents (bool): Whether to return style latents. Default: False.
            return_rgb (bool): Whether to return intermediate RGB images. Default: True.
            randomize_noise (bool): Randomize noise, used when 'noise' is None. Default: True.
"""
conditions = []
unet_skips = []
out_rgbs = []
feat = paddle.nn.functional.leaky_relu(self.conv_body_first(x),
negative_slope=0.2)
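        # encoder: downsample step by step and keep the UNet skip features
        # (stored deepest-first so they align with the decoder loop below)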
for i in range(self.log_size - 2):
feat = self.conv_body_down[i](feat)
unet_skips.insert(0, feat)
feat = paddle.nn.functional.leaky_relu(self.final_conv(feat),
negative_slope=0.2)
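        # flatten the deepest feature and project it to the StyleGAN2 style code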
style_code = self.final_linear(feat.reshape([feat.shape[0], -1]))
if self.different_w:
style_code = style_code.reshape(
[style_code.shape[0], -1, self.num_style_feat])
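        # decoder: upsample with the UNet skips and predict per-resolution SFT
        # (scale, shift) conditions that modulate the StyleGAN2 decoder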
for i in range(self.log_size - 2):
feat = feat + unet_skips[i]
feat = self.conv_body_up[i](feat)
scale = self.condition_scale[i](feat)
conditions.append(scale.clone())
shift = self.condition_shift[i](feat)
conditions.append(shift.clone())
if return_rgb:
out_rgbs.append(self.toRGB[i](feat))
image, _ = self.stylegan_decoder(styles=[style_code],
conditions=conditions,
return_latents=return_latents,
input_is_latent=self.input_is_latent,
randomize_noise=randomize_noise)
if return_latents:
return image, _
else:
return image, out_rgbs
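# A minimal smoke-test sketch (illustrative only; the 512x512 input size and the
# hyper-parameters below are assumptions, not values taken from a training config):
if __name__ == '__main__':
    net = GFPGANv1Clean(out_size=512,
                        num_style_feat=512,
                        channel_multiplier=1,
                        decoder_load_path=None,
                        fix_decoder=False,
                        sft_half=True)
    lq = paddle.rand([1, 3, 512, 512])  # dummy low-quality face batch
    restored, out_rgbs = net(lq, return_rgb=True)
    print(restored.shape, [rgb.shape for rgb in out_rgbs])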
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import random
import paddle
from paddle import nn
from paddle.nn import functional as F
class NormStyleCode(nn.Layer):
def forward(self, x):
"""Normalize the style codes.
Args:
x (Tensor): Style codes with shape (b, c).
Returns:
Tensor: Normalized tensor.
"""
        return x * paddle.rsqrt(
            paddle.mean(x**2, axis=1, keepdim=True) + 1e-08)
class ModulatedConv2d(nn.Layer):
"""Modulated Conv2d used in StyleGAN2.
There is no bias in ModulatedConv2d.
Args:
in_channels (int): Channel number of the input.
out_channels (int): Channel number of the output.
kernel_size (int): Size of the convolving kernel.
num_style_feat (int): Channel number of style features.
demodulate (bool): Whether to demodulate in the conv layer. Default: True.
sample_mode (str | None): Indicating 'upsample', 'downsample' or None. Default: None.
eps (float): A value added to the denominator for numerical stability. Default: 1e-8.
"""
def __init__(self,
in_channels,
out_channels,
kernel_size,
num_style_feat,
demodulate=True,
sample_mode=None,
eps=1e-08):
super(ModulatedConv2d, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.demodulate = demodulate
self.sample_mode = sample_mode
self.eps = eps
self.modulation = nn.Linear(num_style_feat, in_channels, bias_attr=True)
# default_init_weights(self.modulation, scale=1, bias_fill=1, a=0,
# mode='fan_in', nonlinearity='linear')
        x = paddle.randn(
            shape=[1, out_channels, in_channels, kernel_size, kernel_size],
            dtype='float32') / math.sqrt(in_channels * kernel_size**2)
self.weight = paddle.create_parameter(
shape=x.shape,
dtype='float32',
default_initializer=paddle.nn.initializer.Assign(x))
self.weight.stop_gradient = False
self.padding = kernel_size // 2
def forward(self, x, style):
"""Forward function.
Args:
x (Tensor): Tensor with shape (b, c, h, w).
style (Tensor): Tensor with shape (b, num_style_feat).
Returns:
Tensor: Modulated tensor after convolution.
"""
b, c, h, w = x.shape
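        # modulation: project the style vector to one scale per input channel
        # and scale the shared conv weight per sample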
style = self.modulation(style).reshape([b, 1, c, 1, 1])
weight = self.weight * style
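        # demodulation: renormalize each output filter to unit norm so the
        # modulation does not change activation magnitudes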
if self.demodulate:
demod = paddle.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
weight = weight * demod.reshape([b, self.out_channels, 1, 1, 1])
weight = weight.reshape(
[b * self.out_channels, c, self.kernel_size, self.kernel_size])
if self.sample_mode == 'upsample':
x = F.interpolate(x,
scale_factor=2,
mode='bilinear',
align_corners=False)
elif self.sample_mode == 'downsample':
x = F.interpolate(x,
scale_factor=0.5,
mode='bilinear',
align_corners=False)
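        # grouped-conv trick: fold the batch into the channel axis so every
        # sample is convolved with its own modulated kernel in one conv2d call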
b, c, h, w = x.shape
x = x.reshape([1, b * c, h, w])
out = paddle.nn.functional.conv2d(x,
weight,
padding=self.padding,
groups=b)
out = out.reshape([b, self.out_channels, *out.shape[2:4]])
return out
def __repr__(self):
        return (f'{self.__class__.__name__}('
                f'in_channels={self.in_channels}, '
                f'out_channels={self.out_channels}, '
                f'kernel_size={self.kernel_size}, '
                f'demodulate={self.demodulate}, '
                f'sample_mode={self.sample_mode})')
class StyleConv(nn.Layer):
"""Style conv used in StyleGAN2.
Args:
in_channels (int): Channel number of the input.
out_channels (int): Channel number of the output.
kernel_size (int): Size of the convolving kernel.
num_style_feat (int): Channel number of style features.
demodulate (bool): Whether demodulate in the conv layer. Default: True.
sample_mode (str | None): Indicating 'upsample', 'downsample' or None. Default: None.
"""
def __init__(self,
in_channels,
out_channels,
kernel_size,
num_style_feat,
demodulate=True,
sample_mode=None):
super(StyleConv, self).__init__()
self.modulated_conv = ModulatedConv2d(in_channels,
out_channels,
kernel_size,
num_style_feat,
demodulate=demodulate,
sample_mode=sample_mode)
x = paddle.zeros([1], dtype="float32")
self.weight = paddle.create_parameter(
x.shape,
dtype='float32',
default_initializer=paddle.nn.initializer.Assign(
x)) # for noise injection
x = paddle.zeros([1, out_channels, 1, 1], dtype="float32")
self.bias = paddle.create_parameter(
x.shape,
dtype='float32',
default_initializer=paddle.nn.initializer.Assign(x))
self.activate = nn.LeakyReLU(negative_slope=0.2)
def forward(self, x, style, noise=None):
out = self.modulated_conv(x, style) * 2**0.5
if noise is None:
b, _, h, w = out.shape
noise = paddle.normal(shape=[b, 1, h, w])
out = out + self.weight * noise
out = out + self.bias
out = self.activate(out)
return out
class ToRGB(nn.Layer):
"""To RGB (image space) from features.
Args:
in_channels (int): Channel number of input.
num_style_feat (int): Channel number of style features.
upsample (bool): Whether to upsample. Default: True.
"""
def __init__(self, in_channels, num_style_feat, upsample=True):
super(ToRGB, self).__init__()
self.upsample = upsample
self.modulated_conv = ModulatedConv2d(in_channels,
3,
kernel_size=1,
num_style_feat=num_style_feat,
demodulate=False,
sample_mode=None)
x = paddle.zeros(shape=[1, 3, 1, 1], dtype='float32')
self.bias = paddle.create_parameter(
shape=x.shape,
dtype='float32',
default_initializer=paddle.nn.initializer.Assign(x))
self.bias.stop_gradient = False
def forward(self, x, style, skip=None):
"""Forward function.
Args:
x (Tensor): Feature tensor with shape (b, c, h, w).
style (Tensor): Tensor with shape (b, num_style_feat).
skip (Tensor): Base/skip tensor. Default: None.
Returns:
Tensor: RGB images.
"""
out = self.modulated_conv(x, style)
out = out + self.bias
if skip is not None:
if self.upsample:
skip = F.interpolate(skip,
scale_factor=2,
mode='bilinear',
align_corners=False)
out = out + skip
return out
class ConstantInput(nn.Layer):
"""Constant input.
Args:
num_channel (int): Channel number of constant input.
size (int): Spatial size of constant input.
"""
def __init__(self, num_channel, size):
super(ConstantInput, self).__init__()
x = paddle.randn(shape=[1, num_channel, size, size], dtype='float32')
self.weight = paddle.create_parameter(
shape=x.shape,
dtype='float32',
default_initializer=paddle.nn.initializer.Assign(x))
self.weight.stop_gradient = False
def forward(self, batch):
out = paddle.tile(self.weight, repeat_times=[batch, 1, 1, 1])
return out
class StyleGAN2GeneratorClean(nn.Layer):
"""Clean version of StyleGAN2 Generator.
Args:
out_size (int): The spatial size of outputs.
num_style_feat (int): Channel number of style features. Default: 512.
num_mlp (int): Layer number of MLP style layers. Default: 8.
channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
narrow (float): Narrow ratio for channels. Default: 1.0.
"""
def __init__(self,
out_size,
num_style_feat=512,
num_mlp=8,
channel_multiplier=2,
narrow=1):
super(StyleGAN2GeneratorClean, self).__init__()
self.num_style_feat = num_style_feat
style_mlp_layers = [NormStyleCode()]
for i in range(num_mlp):
style_mlp_layers.extend([
nn.Linear(num_style_feat, num_style_feat, bias_attr=True),
nn.LeakyReLU(negative_slope=0.2)
])
self.style_mlp = nn.Sequential(*style_mlp_layers)
# default_init_weights(self.style_mlp, scale=1, bias_fill=0, a=0.2,
# mode='fan_in', nonlinearity='leaky_relu')
channels = {
'4': int(512 * narrow),
'8': int(512 * narrow),
'16': int(512 * narrow),
'32': int(512 * narrow),
'64': int(256 * channel_multiplier * narrow),
'128': int(128 * channel_multiplier * narrow),
'256': int(64 * channel_multiplier * narrow),
'512': int(32 * channel_multiplier * narrow),
'1024': int(16 * channel_multiplier * narrow)
}
self.channels = channels
self.constant_input = ConstantInput(channels['4'], size=4)
self.style_conv1 = StyleConv(channels['4'],
channels['4'],
kernel_size=3,
num_style_feat=num_style_feat,
demodulate=True,
sample_mode=None)
self.to_rgb1 = ToRGB(channels['4'], num_style_feat, upsample=False)
self.log_size = int(math.log(out_size, 2))
self.num_layers = (self.log_size - 2) * 2 + 1
self.num_latent = self.log_size * 2 - 2
self.style_convs = nn.LayerList()
self.to_rgbs = nn.LayerList()
self.noises = nn.Layer()
in_channels = channels['4']
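        # register one fixed noise buffer per conv layer, at that layer's
        # feature resolution (used when randomize_noise is False)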
for layer_idx in range(self.num_layers):
resolution = 2**((layer_idx + 5) // 2)
shape = [1, 1, resolution, resolution]
self.noises.register_buffer(f'noise{layer_idx}',
paddle.randn(shape=shape))
for i in range(3, self.log_size + 1):
out_channels = channels[f'{2 ** i}']
            self.style_convs.append(
                StyleConv(in_channels, out_channels, kernel_size=3,
                          num_style_feat=num_style_feat, demodulate=True,
                          sample_mode='upsample'))
            self.style_convs.append(
                StyleConv(out_channels, out_channels, kernel_size=3,
                          num_style_feat=num_style_feat, demodulate=True,
                          sample_mode=None))
self.to_rgbs.append(
ToRGB(out_channels, num_style_feat, upsample=True))
in_channels = out_channels
def make_noise(self):
"""Make noise for noise injection."""
        # noise tensors are created on paddle's current default device
noises = [paddle.randn(shape=[1, 1, 4, 4])]
for i in range(3, self.log_size + 1):
for _ in range(2):
noises.append(paddle.randn(shape=[1, 1, 2**i, 2**i]))
return noises
def get_latent(self, x):
return self.style_mlp(x)
def mean_latent(self, num_latent):
latent_in = paddle.randn(shape=[num_latent, self.num_style_feat])
latent = self.style_mlp(latent_in).mean(0, keepdim=True)
return latent
def forward(self,
styles,
input_is_latent=False,
noise=None,
randomize_noise=True,
truncation=1,
truncation_latent=None,
inject_index=None,
return_latents=False):
"""Forward function for StyleGAN2GeneratorClean.
Args:
styles (list[Tensor]): Sample codes of styles.
input_is_latent (bool): Whether input is latent style. Default: False.
noise (Tensor | None): Input noise or None. Default: None.
            randomize_noise (bool): Randomize noise, used when 'noise' is None. Default: True.
truncation (float): The truncation ratio. Default: 1.
truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
inject_index (int | None): The injection index for mixing noise. Default: None.
return_latents (bool): Whether to return style latents. Default: False.
"""
if not input_is_latent:
styles = [self.style_mlp(s) for s in styles]
if noise is None:
if randomize_noise:
noise = [None] * self.num_layers
else:
noise = [
getattr(self.noises, f'noise{i}')
for i in range(self.num_layers)
]
if truncation < 1:
style_truncation = []
for style in styles:
style_truncation.append(truncation_latent + truncation *
(style - truncation_latent))
styles = style_truncation
if len(styles) == 1:
inject_index = self.num_latent
if styles[0].ndim < 3:
                latent = paddle.tile(styles[0].unsqueeze(1),
                                     repeat_times=[1, inject_index, 1])
else:
latent = styles[0]
elif len(styles) == 2:
if inject_index is None:
inject_index = random.randint(1, self.num_latent - 1)
            latent1 = paddle.tile(styles[0].unsqueeze(1),
                                  repeat_times=[1, inject_index, 1])
            latent2 = paddle.tile(
                styles[1].unsqueeze(1),
                repeat_times=[1, self.num_latent - inject_index, 1])
latent = paddle.concat([latent1, latent2], axis=1)
out = self.constant_input(latent.shape[0])
out = self.style_conv1(out, latent[:, 0], noise=noise[0])
skip = self.to_rgb1(out, latent[:, 1])
i = 1
for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2],
self.style_convs[1::2],
noise[1::2],
noise[2::2],
self.to_rgbs):
out = conv1(out, latent[:, i], noise=noise1)
out = conv2(out, latent[:, i + 1], noise=noise2)
skip = to_rgb(out, latent[:, i + 2], skip)
i += 2
image = skip
if return_latents:
return image, latent
else:
return image, None
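# A minimal sampling sketch (illustrative only; out_size, batch size and the
# truncation settings are assumptions, and no pre-trained weights are loaded):
if __name__ == '__main__':
    gen = StyleGAN2GeneratorClean(out_size=256, num_style_feat=512)
    z = paddle.randn(shape=[2, 512])      # random z codes
    mean_w = gen.mean_latent(1024)        # estimate the mean latent for truncation
    fake, _ = gen([z], truncation=0.7, truncation_latent=mean_w)
    print(fake.shape)  # expected: [2, 3, 256, 256]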
===========================train_params===========================
model_name:GFPGAN
python:python3.7
gpu_list:0
##
auto_cast:null
total_iters:lite_train_lite_infer=10
output_dir:./output/
dataset.train.batch_size:lite_train_lite_infer=3
pretrained_model:null
train_model_name:gfpgan_ffhq1024*/*checkpoint.pdparams
train_infer_img_dir:null
null:null
##
trainer:norm_train
norm_train:tools/main.py -c configs/gfpgan_ffhq1024.yaml --seed 123 -o log_config.interval=1 snapshot_config.interval=10
pact_train:null
fpgm_train:null
distill_train:null
null:null
null:null
##
===========================eval_params===========================
eval:null
null:null
##
===========================infer_params===========================
--output_dir:./output/
load:null
norm_export:tools/export_model.py -c configs/gfpgan_ffhq1024.yaml --inputs_size="1,3,512,512" --model_name inference --load
quant_export:null
fpgm_export:null
distill_export:null
export1:null
export2:null
inference_dir:inference
train_model:./inference/stylegan2/stylegan2model_gen
infer_export:null
infer_quant:False
inference:tools/inference.py --model_type gfpgan --seed 123 -c configs/gfpgan_ffhq1024.yaml --output_path test_tipc/output/ -o validate=None
--device:gpu
null:null
null:null
null:null
null:null
null:null
--model_path:
null:null
null:null
--benchmark:False
null:null
===========================train_benchmark_params==========================
batch_size:8
fp_items:fp32
epoch:100
--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
flags:FLAGS_cudnn_exhaustive_search=1
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[1, 512]}, {float32,[1]}]
@@ -76,6 +76,11 @@ if [ ${MODE} = "lite_train_lite_infer" ];then
cd ./data/ && unzip -q singan-official_images.zip && cd ../
mkdir -p ./data/singan
mv ./data/SinGAN-official_images/Images/stone.png ./data/singan ;;
GFPGAN)
rm -rf ./data/gfpgan*
wget -nc -P ./data/ https://paddlegan.bj.bcebos.com/datasets/gfpgan_tipc_data.zip --no-check-certificate
mkdir -p ./data/gfpgan_data
cd ./data/ && unzip -q gfpgan_tipc_data.zip -d gfpgan_data/ && cd ../ ;;
esac
elif [ ${MODE} = "whole_train_whole_infer" ];then
if [ ${model_name} == "Pix2pix" ]; then
......
@@ -334,6 +334,15 @@ def main():
metric_file = os.path.join(args.output_path, "singan/metric.txt")
for metric in metrics.values():
metric.update(prediction, data['A'])
elif model_type == 'gfpgan':
input_handles[0].copy_from_cpu(data['lq'].numpy())
predictor.run()
prediction = output_handle.copy_to_cpu()
prediction = paddle.to_tensor(prediction)
image_numpy = tensor2img(prediction, min_max)
save_image(
image_numpy,
os.path.join(args.output_path, "gfpgan/{}.png".format(i)))
elif model_type == "swinir":
lq = data[1].numpy()
_, _, h_old, w_old = lq.shape
......