Unverified | Commit 9a31bdf9 authored by bitcjm, committed by GitHub

add the model file and (#634)

* add the model files and tipc of GPEN

* add the model files and tipc of GPEN
Parent 1e1e2ad2
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys

sys.path.append(".")

import argparse

import paddle

from ppgan.apps import GPENPredictor

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--output_path",
                        type=str,
                        default='output_dir',
                        help="path to output image dir")
    parser.add_argument("--weight_path",
                        type=str,
                        default=None,
                        help="path to model checkpoint")
    parser.add_argument("--test_img",
                        type=str,
                        default='data/gpen/lite_data/15006.png',
                        help="path of test image")
    parser.add_argument("--model_type",
                        type=str,
                        default=None,
                        help="type of model for loading pretrained model")
    parser.add_argument("--seed",
                        type=int,
                        default=None,
                        help="random seed for the model's image generation")
    parser.add_argument("--size",
                        type=int,
                        default=256,
                        help="resolution of output image")
    parser.add_argument("--style_dim",
                        type=int,
                        default=512,
                        help="number of style dimensions")
    parser.add_argument("--n_mlp",
                        type=int,
                        default=8,
                        help="depth of the mlp (number of layers)")
    parser.add_argument("--channel_multiplier",
                        type=int,
                        default=1,
                        help="channel multiplier")
    parser.add_argument("--narrow",
                        type=float,
                        default=0.5,
                        help="channel narrowing factor")
    parser.add_argument("--cpu",
                        dest="cpu",
                        action="store_true",
                        help="cpu mode.")
    args = parser.parse_args()

    if args.cpu:
        paddle.set_device('cpu')

    predictor = GPENPredictor(output_path=args.output_path,
                              weight_path=args.weight_path,
                              model_type=args.model_type,
                              seed=args.seed,
                              size=args.size,
                              style_dim=args.style_dim,
                              n_mlp=args.n_mlp,
                              narrow=args.narrow,
                              channel_multiplier=args.channel_multiplier)
    predictor.run(args.test_img)
total_iters: 200000
output_dir: output_dir
find_unused_parameters: True

model:
  name: GPENModel
  generator:
    name: GPEN
    size: 256
    style_dim: 512
    n_mlp: 8
    channel_multiplier: 1
    narrow: 0.5
  discriminator:
    name: GPENDiscriminator
    size: 256
    channel_multiplier: 1
    narrow: 0.5

export_model:
  - {name: 'g_ema', inputs_num: 1}

dataset:
  train:
    name: GPENDataset
    dataroot: data/ffhq/images256x256/
    num_workers: 0
    batch_size: 2  # 1 GPU
    size: 256
  test:
    name: GPENDataset
    dataroot: data/ffhq/images256x256/
    num_workers: 0
    batch_size: 1
    size: 256
    amount: 100

lr_scheduler:
  name: CosineAnnealingRestartLR
  learning_rate: 0.002
  periods: [500000, 500000, 500000, 500000]
  restart_weights: [1, 1, 1, 1]
  eta_min: 0.002

optimizer:
  optimG:
    name: Adam
    net_names:
      - netG
    beta1: 0.9
    beta2: 0.99
  optimD:
    name: Adam
    net_names:
      - netD
    beta1: 0.9
    beta2: 0.99

log_config:
  interval: 100
  visiual_interval: 500

snapshot_config:
  interval: 5000

validate:
  interval: 5000
  save_img: false
  metrics:
    fid:
      name: FID
      batch_size: 1
English | [Chinese](../../zh_CN/tutorials/gpen.md)
## GPEN Blind Face Restoration Model
## 1、Introduction
GPEN is a blind face restoration model. The authors embed the decoder of the previously proposed StyleGAN V2 as GPEN's decoder, and rebuild a simple DNN encoder to provide the decoder's input. This keeps the excellent performance of the StyleGAN V2 decoder while changing the model's task from image style transfer to blind face restoration. The overall structure of the model is shown in the following figure:
![img](https://user-images.githubusercontent.com/23252220/168281766-a0972bd3-243e-4fc7-baa5-e458ef0946ce.jpg)
For a more detailed introduction to the model and the reference repo, see the latest version of the AI Studio project [GPEN Blind Face Restoration Model Reproduction - Paddle AI Studio](https://aistudio.baidu.com/aistudio/projectdetail/3936241?contributionType=1).
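To make the encoder-to-decoder hand-off concrete, the sketch below mirrors the forward pass of the `GPEN` generator added in this commit: the encoder progressively downsamples the degraded face, each intermediate feature map is reused as a per-resolution noise input of the StyleGAN V2 decoder, and the bottleneck vector becomes the style code. The function and argument names here are illustrative, not part of the API.

```python
# Illustrative sketch of GPEN's forward pass (see ppgan/models/generators/gpen.py)
def gpen_forward_sketch(encoder_blocks, final_linear, stylegan_decoder, x):
    # x: degraded face, shape [N, 3, 256, 256]
    noise = []
    for block in encoder_blocks:      # downsample step by step to 4x4
        x = block(x)
        noise.append(x)               # keep the feature map of every scale
    style = final_linear(x.reshape([x.shape[0], -1]))  # bottleneck -> style code
    # the decoder consumes the encoder features (coarsest first) as noise maps
    image, _ = stylegan_decoder([style], noise=noise[::-1])
    return image
```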
## 2、Preparation
### 2.1 Dataset Preparation
The GPEN training set is the classic FFHQ face dataset, 70,000 clear 1024 x 1024 high-resolution face images in total; the test set is the CELEBA-HQ dataset, 2,000 high-resolution face images in total. For details, see the dataset pages: [FFHQ](https://github.com/NVlabs/ffhq-dataset) and [CELEBA-HQ](https://github.com/tkarras/progressive_growing_of_gans). The specific download links are given below:
**Original dataset download address:**
**FFHQ :** https://drive.google.com/drive/folders/1tZUcXDBeOibC6jcMCtgRRz67pzrAHeHL?usp=drive_open
**CELEBA-HQ:** https://drive.google.com/drive/folders/0B4qLcYyJmiz0TXY1NG02bzZVRGs?resourcekey=0-arAVTUfW9KRhN-irJchVKQ&usp=sharing
Since the original FFHQ dataset is too large, you can also download the 256-resolution FFHQ dataset from the following link:
https://paddlegan.bj.bcebos.com/datasets/images256x256.tar
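For example, to fetch and unpack it (the destination directory here is an assumption; make it match the `dataroot` in your config):

```shell
wget https://paddlegan.bj.bcebos.com/datasets/images256x256.tar
mkdir -p data/ffhq
tar -xf images256x256.tar -C data/ffhq/
```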
**After downloading, the file organization is as follows**
```
|-- data/GPEN
|-- ffhq/images256x256/
|-- 00000
|-- 00000.png
|-- 00001.png
|-- ......
|-- 00999.png
|-- 01000
|-- ......
|-- ......
|-- 69000
|-- ......
|-- 69999.png
|-- test
|-- 2000 PNG images
```
Please set the `dataroot` parameters of `dataset.train` and `dataset.test` in the `configs/gpen_256_ffhq.yaml` configuration file to your training set and test set paths.
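As a quick sanity check that `dataroot` points at the right place, you can count images with the same glob patterns `GPENDataset` uses (`*/*.*g`, falling back to `*.*g` for a flat directory); the path below is only an example:

```python
import glob
import os

dataroot = "data/ffhq/images256x256/"  # adjust to your dataset path
imgs = sorted(glob.glob(os.path.join(dataroot, "*/*.*g")))
if not imgs:  # flat layout fallback, as in GPENDataset
    imgs = sorted(glob.glob(os.path.join(dataroot, "*.*g")))
print(f"found {len(imgs)} images")  # expect 70000 for the full FFHQ-256 set
```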
### 2.2 Model preparation
**Model parameter file and training log download address:**
link:https://paddlegan.bj.bcebos.com/models/gpen.zip
Download the model parameters and test images from the link and put them in the data/ folder in the project root directory. The specific file structure is as follows:
```
data/gpen/weights
|-- model_ir_se50.pdparams
|-- weight_pretrain.pdparams
data/gpen/lite_data
```
## 3、Getting started
### 3.1 Model training
Enter the following command in the console to start training:
```shell
python tools/main.py -c configs/gpen_256_ffhq.yaml
```
The model supports single-GPU training only.
Model training requires paddle 2.3 or later, and also depends on paddle implementing the second-order operator for elementwise_pow. With paddle 2.2.2 the code runs, but the model cannot be trained successfully because some loss functions compute incorrect gradients. If an error is reported during training, training is not supported for the time being; you can skip the training part and test directly with the provided model parameters. Model evaluation and prediction work with paddle 2.2.2 and later.
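A quick way to check the installed version before launching a long run (a minimal sketch based on the version requirements above):

```python
import paddle

# Training needs paddle >= 2.3 (second-order gradient of elementwise_pow);
# evaluation and prediction work from 2.2.2 onward.
major, minor = (int(v) for v in paddle.__version__.split(".")[:2])
if (major, minor) < (2, 3):
    print(f"paddle {paddle.__version__}: evaluation only, training may fail")
```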
### 3.2 Model evaluation
To evaluate the model, enter the following command in the console, using the downloaded model parameters mentioned above:
```shell
python tools/main.py -c configs/gpen_256_ffhq.yaml -o dataset.test.amount=2000 --load data/gpen/weights/weight_pretrain.pdparams --evaluate-only
```
If you want to test your own model, modify the path after `--load`.
### 3.3 Model prediction
#### 3.3.1 Export generator weights
After training, use ``tools/extract_weight.py`` to extract the generator's weights from the trained model (which contains both the generator and the discriminator) so that `applications/tools/gpen.py` can run inference and serve the various applications of the GPEN model. Enter the following command to extract the generator weights:
```bash
python tools/extract_weight.py data/gpen/weights/weight_pretrain.pdparams --net-name g_ema --output data/gpen/weights/g_ema.pdparams
```
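Under the hood, `tools/extract_weight.py` loads the full checkpoint and saves one sub-network's state dict. A minimal hand-rolled equivalent, assuming the checkpoint is a dict of per-network state dicts keyed by names like `g_ema` (which is what `--net-name g_ema` selects), would be:

```python
import paddle

# Assumption: the training checkpoint stores each sub-network under its own key.
ckpt = paddle.load("data/gpen/weights/weight_pretrain.pdparams")
paddle.save(ckpt["g_ema"], "data/gpen/weights/g_ema.pdparams")
```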
#### 3.3.2 Process a single image
After extracting the generator weights, enter the following command to test the image under the `--test_img` path. Changing the `--seed` parameter generates different degraded images for richer results. You can change the path after `--test_img` to any image you want to test. If no weight is provided after `--weight_path`, the trained model weights are downloaded automatically for testing.
```bash
python applications/tools/gpen.py --test_img data/gpen/lite_data/15006.png --seed=100 --weight_path data/gpen/weights/g_ema.pdparams --model_type gpen-ffhq-256
```
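The same prediction can also be driven from Python through the `GPENPredictor` class added in this commit; the arguments mirror the CLI flags above:

```python
from ppgan.apps import GPENPredictor

predictor = GPENPredictor(output_path="output_dir",
                          weight_path="data/gpen/weights/g_ema.pdparams",
                          model_type="gpen-ffhq-256",
                          seed=100,
                          size=256)
predictor.run("data/gpen/lite_data/15006.png")  # writes output_dir/gpen_predict.png
```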
Below are a sample image and the corresponding restored result; from left to right: the degraded input, the generated image, and the original clear image:
<p align='center'>
<img src="https://user-images.githubusercontent.com/23252220/168281788-39c08e86-2dc3-487f-987d-93489934c14c.png" height="256px" width='768px' >
</p>
An example output is as follows:
```
result saved in : output_dir/gpen_predict.png
FID: 92.11730631094356
PSNR:19.014782083825743
```
## 4、TIPC
### 4.1 Export the inference model
```bash
python tools/export_model.py -c configs/gpen_256_ffhq.yaml --inputs_size=1,3,256,256 --load data/gpen/weights/weight_pretrain.pdparams
```
The above command generates the model structure file `gpenmodel_g_ema.pdmodel` and the model weight files `gpenmodel_g_ema.pdiparams` and `gpenmodel_g_ema.pdiparams.info` required for prediction, all stored in the `inference_model/` directory. You can also change the parameter after `--load` to the model parameter file you want to test.
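The exported files can then be loaded with the Paddle Inference engine. A minimal loading sketch (file names follow the output described above) looks like this:

```python
from paddle.inference import Config, create_predictor

config = Config("inference_model/gpenmodel_g_ema.pdmodel",
                "inference_model/gpenmodel_g_ema.pdiparams")
predictor = create_predictor(config)
print(predictor.get_input_names())  # expects one input of shape [1, 3, 256, 256]
```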
### 4.2 Inference with a prediction engine
```bash
python tools/inference.py --model_type GPEN --seed 100 -c configs/gpen_256_ffhq.yaml -o dataset.test.dataroot="./data/gpen/lite_data/" --output_path test_tipc/output/ --model_path inference_model/gpenmodel_g_ema
```
When inference finishes, the restored images generated by the model are saved by default under the test_tipc/output/GPEN directory, and the FID value obtained by the test is written to test_tipc/output/GPEN/metric.txt.
The default output is as follows:
```
Metric fid: 187.0158
```
Note: since degrading a high-definition picture is partly random, results differ from run to run. To keep test results consistent, the random seed is fixed here so that the same degradation is applied to the image on every test.
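Concretely, fixing the seed pins all three random sources used by the degradation pipeline; this is the same trio `GPENPredictor` seeds when `--seed` is given:

```python
import random

import numpy as np
import paddle

seed = 100
paddle.seed(seed)
random.seed(seed)
np.random.seed(seed)  # the degradation pipeline draws from np.random
```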
### 4.3 Run the full train-and-infer TIPC test with two script calls
To test the basic training and inference pipeline in `lite_train_lite_infer` mode, run:
```shell
# fix line endings of the shell scripts
sed -i 's/\r//' test_tipc/prepare.sh
sed -i 's/\r//' test_tipc/test_train_inference_python.sh
sed -i 's/\r//' test_tipc/common_func.sh
# prepare data
bash test_tipc/prepare.sh ./test_tipc/configs/GPEN/train_infer_python.txt 'lite_train_lite_infer'
# run the test
bash test_tipc/test_train_inference_python.sh ./test_tipc/configs/GPEN/train_infer_python.txt 'lite_train_lite_infer'
```
## 5、References
```
@misc{2021GAN,
title={GAN Prior Embedded Network for Blind Face Restoration in the Wild},
author={ Yang, T. and Ren, P. and Xie, X. and Zhang, L. },
year={2021},
archivePrefix={CVPR},
primaryClass={cs.CV}
}
```
[English](../../en_US/tutorials/gpen.md) | Chinese
## GPEN Blind Face Restoration Model
## 1、Introduction
GPEN is a blind face restoration model. The authors embed the decoder of the previously proposed StyleGAN V2 as GPEN's decoder, and rebuild a simple DNN encoder to provide the decoder's input. This keeps the excellent performance of the StyleGAN V2 decoder while changing the model's task from image style transfer to blind face restoration. The overall structure of the model is shown in the following figure:
![img](https://user-images.githubusercontent.com/23252220/168281766-a0972bd3-243e-4fc7-baa5-e458ef0946ce.jpg)
For a more detailed introduction to the model and the reference repo, see the latest version of the AI Studio project [GPEN Blind Face Restoration Model Reproduction - Paddle AI Studio](https://aistudio.baidu.com/aistudio/projectdetail/3936241?contributionType=1).
## 2、Preparation
### 2.1 Dataset preparation
The GPEN training set is the classic FFHQ face dataset, 70,000 clear 1024 x 1024 high-resolution face images in total; the test set is the CELEBA-HQ dataset, 2,000 high-resolution face images in total. For details, see the dataset pages: [FFHQ](https://github.com/NVlabs/ffhq-dataset) and [CELEBA-HQ](https://github.com/tkarras/progressive_growing_of_gans). The specific download links are given below:
**Original dataset download addresses:**
**FFHQ :** https://drive.google.com/drive/folders/1tZUcXDBeOibC6jcMCtgRRz67pzrAHeHL?usp=drive_open
**CELEBA-HQ:** https://drive.google.com/drive/folders/0B4qLcYyJmiz0TXY1NG02bzZVRGs?resourcekey=0-arAVTUfW9KRhN-irJchVKQ&usp=sharing
Because the original FFHQ dataset is very large, you can instead download the 256-resolution FFHQ dataset from the following link:
https://paddlegan.bj.bcebos.com/datasets/images256x256.tar
**After downloading, the suggested file organization is as follows**
```
|-- data/GPEN
|-- ffhq/images256x256/
|-- 00000
|-- 00000.png
|-- 00001.png
|-- ......
|-- 00999.png
|-- 01000
|-- ......
|-- ......
|-- 69000
|-- ......
|-- 69999.png
|-- test
|-- 2000 PNG images
```
Please set the `dataroot` parameters of `dataset.train` and `dataset.test` in the `configs/gpen_256_ffhq.yaml` configuration file to your training set and test set paths.
### 2.2 Model preparation
**Download address for the model parameter files and training logs:**
Link: https://paddlegan.bj.bcebos.com/models/gpen.zip
Download the model parameters and test images from the link and put them in the data/ folder under the project root. The specific file structure is as follows:
**File structure**
```
data/gpen/weights
    |-- model_ir_se50.pdparams   # facenet weights loaded to compute id_loss
    |-- weight_pretrain.pdparams # 256-resolution weights containing both the generator and the discriminator;
                                 # only the generator weights are fully trained. The file format matches the
                                 # checkpoints saved during training in 3.1; it is also used in 3.2, 3.3.1 and 4.1
data/gpen/lite_data
```
## 3、Getting started
### 3.1 Model training
Enter the following command in the console to start training:
```shell
python tools/main.py -c configs/gpen_256_ffhq.yaml
```
The model supports single-GPU training only.
Model training requires paddle 2.3 or later, and also depends on paddle implementing the second-order operator for elementwise_pow. With paddle 2.2.2 the code runs, but the model cannot be trained successfully because some loss functions compute incorrect gradients. If an error is reported during training, training is not supported for the time being; you can skip the training part and test directly with the provided model parameters. Model evaluation and prediction work with paddle 2.2.2 and later.
### 3.2 Model evaluation
To evaluate the model, enter the following command in the console, using the downloaded model parameters mentioned above:
```shell
python tools/main.py -c configs/gpen_256_ffhq.yaml -o dataset.test.amount=2000 --load data/gpen/weights/weight_pretrain.pdparams --evaluate-only
```
If you want to test your own model, modify the path after `--load`.
### 3.3 Model prediction
#### 3.3.1 Export generator weights
After training, use ``tools/extract_weight.py`` to extract the generator's weights from the trained model (which contains both the generator and the discriminator) so that `applications/tools/gpen.py` can run inference and serve the various applications of the GPEN model. Enter the following command to extract the generator weights:
```bash
python tools/extract_weight.py data/gpen/weights/weight_pretrain.pdparams --net-name g_ema --output data/gpen/weights/g_ema.pdparams
```
#### 3.3.2 Process a single image
After extracting the generator weights, enter the following command to test the image under the `--test_img` path. Changing the `--seed` parameter generates different degraded images for richer results. You can change the path after `--test_img` to any image you want to test. If no weight is provided after `--weight_path`, the trained model weights are downloaded automatically for testing.
```bash
python applications/tools/gpen.py --test_img data/gpen/lite_data/15006.png --seed=100 --weight_path data/gpen/weights/g_ema.pdparams --model_type gpen-ffhq-256
```
Below are a sample image and the corresponding restored result; from left to right: the degraded input, the generated image, and the original clear image:
<p align='center'>
<img src="https://user-images.githubusercontent.com/23252220/168281788-39c08e86-2dc3-487f-987d-93489934c14c.png" height="256px" width='768px' >
</p>
An example output is as follows:
```
result saved in : output_dir/gpen_predict.png
FID: 92.11730631094356
PSNR:19.014782083825743
```
## 4、TIPC
### 4.1 Export the inference model
```bash
python tools/export_model.py -c configs/gpen_256_ffhq.yaml --inputs_size=1,3,256,256 --load data/gpen/weights/weight_pretrain.pdparams
```
The above command generates the model structure file `gpenmodel_g_ema.pdmodel` and the model weight files `gpenmodel_g_ema.pdiparams` and `gpenmodel_g_ema.pdiparams.info` required for prediction, all stored in the `inference_model/` directory. You can also change the parameter after `--load` to the model parameter file you want to test.
### 4.2 Inference with the prediction engine
```bash
python tools/inference.py --model_type GPEN --seed 100 -c configs/gpen_256_ffhq.yaml -o dataset.test.dataroot="./data/gpen/lite_data/" --output_path test_tipc/output/ --model_path inference_model/gpenmodel_g_ema
```
When inference finishes, the restored images generated by the model are saved by default under the test_tipc/output/GPEN directory, and the FID value obtained by the test is written to test_tipc/output/GPEN/metric.txt.
The default output is as follows:
```
Metric fid: 187.0158
```
Note: since degrading a high-definition picture is partly random, results differ from run to run. To keep test results consistent, the random seed is fixed here so that the same degradation is applied to the image on every test.
### 4.3 Run the full train-and-infer TIPC test with two script calls
To test the basic training and inference pipeline in `lite_train_lite_infer` mode, run:
```shell
# fix line endings of the shell scripts
sed -i 's/\r//' test_tipc/prepare.sh
sed -i 's/\r//' test_tipc/test_train_inference_python.sh
sed -i 's/\r//' test_tipc/common_func.sh
# prepare data
bash test_tipc/prepare.sh ./test_tipc/configs/GPEN/train_infer_python.txt 'lite_train_lite_infer'
# run the test
bash test_tipc/test_train_inference_python.sh ./test_tipc/configs/GPEN/train_infer_python.txt 'lite_train_lite_infer'
```
## 5、References
```
@misc{2021GAN,
title={GAN Prior Embedded Network for Blind Face Restoration in the Wild},
author={ Yang, T. and Ren, P. and Xie, X. and Zhang, L. },
year={2021},
archivePrefix={CVPR},
primaryClass={cs.CV}
}
```
@@ -35,3 +35,4 @@ from .recurrent_vsr_predictor import (PPMSVSRPredictor, BasicVSRPredictor, \
                                       BasiVSRPlusPlusPredictor, IconVSRPredictor, \
                                       PPMSVSRLargePredictor)
 from .singan_predictor import SinGANPredictor
+from .gpen_predictor import GPENPredictor
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import os
import random
import sys

import cv2
import numpy as np
import paddle

sys.path.append(".")

from .base_predictor import BasePredictor
from ppgan.datasets.gpen_dataset import GFPGAN_degradation
from ppgan.models.generators import GPEN
from ppgan.metrics.fid import FID
from ppgan.utils.download import get_path_from_url

import warnings

model_cfgs = {
    'gpen-ffhq-256': {
        'model_urls':
        'https://paddlegan.bj.bcebos.com/models/gpen-ffhq-256-generator.pdparams',
        'size': 256,
        'style_dim': 512,
        'n_mlp': 8,
        'channel_multiplier': 1,
        'narrow': 0.5
    }
}


def psnr(pred, gt):
    """Peak signal-to-noise ratio between two images in [0, 1]."""
    pred = paddle.clip(pred, min=0, max=1)
    gt = paddle.clip(gt, min=0, max=1)
    imdff = np.asarray(pred - gt)
    rmse = math.sqrt(np.mean(imdff**2))
    if rmse == 0:
        return 100
    return 20 * math.log10(1.0 / rmse)


def data_loader(path, size=256):
    """Load an image, degrade it, and return (lq, gt) as NCHW float32 arrays."""
    degrader = GFPGAN_degradation()

    img_gt = cv2.imread(path, cv2.IMREAD_COLOR)
    img_gt = cv2.resize(img_gt, (size, size), interpolation=cv2.INTER_NEAREST)
    img_gt = img_gt.astype(np.float32) / 255.
    img_gt, img_lq = degrader.degrade_process(img_gt)

    img_gt = (paddle.to_tensor(img_gt) - 0.5) / 0.5
    img_lq = (paddle.to_tensor(img_lq) - 0.5) / 0.5
    # HWC(BGR) -> CHW(RGB), then add a batch dimension
    img_gt = img_gt.transpose([2, 0, 1]).flip(0).unsqueeze(0)
    img_lq = img_lq.transpose([2, 0, 1]).flip(0).unsqueeze(0)

    return np.array(img_lq).astype('float32'), np.array(img_gt).astype(
        'float32')


class GPENPredictor(BasePredictor):

    def __init__(self,
                 output_path='output_dir',
                 weight_path=None,
                 model_type=None,
                 seed=100,
                 size=256,
                 style_dim=512,
                 n_mlp=8,
                 channel_multiplier=1,
                 narrow=0.5):
        self.output_path = output_path
        self.size = size
        if weight_path is None:
            if model_type in model_cfgs.keys():
                weight_path = get_path_from_url(
                    model_cfgs[model_type]['model_urls'])
                size = model_cfgs[model_type].get('size', size)
                style_dim = model_cfgs[model_type].get('style_dim', style_dim)
                n_mlp = model_cfgs[model_type].get('n_mlp', n_mlp)
                channel_multiplier = model_cfgs[model_type].get(
                    'channel_multiplier', channel_multiplier)
                narrow = model_cfgs[model_type].get('narrow', narrow)
                checkpoint = paddle.load(weight_path)
            else:
                raise ValueError(
                    'Predictor needs a weight path or a pretrained model type')
        else:
            checkpoint = paddle.load(weight_path)

        warnings.filterwarnings("always")
        self.generator = GPEN(size, style_dim, n_mlp, channel_multiplier,
                              narrow)
        self.generator.set_state_dict(checkpoint)
        self.generator.eval()

        # fix every random source so the degradation is reproducible
        if seed is not None:
            paddle.seed(seed)
            random.seed(seed)
            np.random.seed(seed)

    def run(self, img_path):
        os.makedirs(self.output_path, exist_ok=True)
        input_array, target_array = data_loader(img_path, self.size)
        input_tensor = paddle.to_tensor(input_array)
        target_tensor = paddle.to_tensor(target_array)

        FID_model = FID(use_GPU=True)

        with paddle.no_grad():
            output, _ = self.generator(input_tensor)
            psnr_score = psnr(target_tensor, output)
            FID_model.update(output, target_tensor)
            fid_score = FID_model.accumulate()

        input_tensor = input_tensor.transpose([0, 2, 3, 1])
        target_tensor = target_tensor.transpose([0, 2, 3, 1])
        output = output.transpose([0, 2, 3, 1])
        # degraded | restored | ground truth, side by side
        sample_result = paddle.concat(
            (input_tensor[0], output[0], target_tensor[0]), 1)
        sample = cv2.cvtColor((sample_result.numpy() + 1) / 2 * 255,
                              cv2.COLOR_RGB2BGR)
        file_name = self.output_path + '/gpen_predict.png'
        cv2.imwrite(file_name, sample)
        print(f"result saved in : {file_name}")
        print(f"\tFID: {fid_score}\n\tPSNR:{psnr_score}")
@@ -30,3 +30,4 @@ from .vsr_vimeo90k_dataset import VSRVimeo90KDataset
 from .vsr_folder_dataset import VSRFolderDataset
 from .photopen_dataset import PhotoPenDataset
 from .empty_dataset import EmptyDataset
+from .gpen_dataset import GPENDataset
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import glob
import logging
import math
import os
import random

import cv2
import numpy as np
import paddle
from paddle.io import Dataset

from .builder import DATASETS

logger = logging.getLogger(__name__)


def generate_gaussian_noise(img, sigma=10, gray_noise=False):
    """Generate Gaussian noise.

    Args:
        img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
        sigma (float): Noise scale (measured in range 255). Default: 10.
        gray_noise (bool): Whether to generate single-channel (gray) noise.

    Returns:
        (Numpy array): Noise map, shape (h, w, c), float32.
    """
    if gray_noise:
        noise = np.float32(np.random.randn(*(img.shape[0:2]))) * sigma / 255.
        noise = np.expand_dims(noise, axis=2).repeat(3, axis=2)
    else:
        noise = np.float32(np.random.randn(*(img.shape))) * sigma / 255.
    return noise


def random_generate_gaussian_noise(img, sigma_range=(0, 10), gray_prob=0):
    sigma = np.random.uniform(sigma_range[0], sigma_range[1])
    if np.random.uniform() < gray_prob:
        gray_noise = True
    else:
        gray_noise = False
    return generate_gaussian_noise(img, sigma, gray_noise)


def random_add_gaussian_noise(img,
                              sigma_range=(0, 1.0),
                              gray_prob=0,
                              clip=True,
                              rounds=False):
    noise = random_generate_gaussian_noise(img, sigma_range, gray_prob)
    out = img + noise
    if clip and rounds:
        out = np.clip((out * 255.0).round(), 0, 255) / 255.
    elif clip:
        out = np.clip(out, 0, 1)
    elif rounds:
        out = (out * 255.0).round() / 255.
    return out


def add_jpg_compression(img, quality=90):
    """Add JPG compression artifacts.

    Args:
        img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
        quality (float): JPG compression quality. 0 for lowest quality, 100 for
            best quality. Default: 90.

    Returns:
        (Numpy array): Returned image after JPG, shape (h, w, c),
            range [0, 1], float32.
    """
    img = np.clip(img, 0, 1)
    encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), quality]
    _, encimg = cv2.imencode('.jpg', img * 255., encode_param)
    img = np.float32(cv2.imdecode(encimg, 1)) / 255.
    return img


def random_add_jpg_compression(img, quality_range=(90, 100)):
    """Randomly add JPG compression artifacts.

    Args:
        img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
        quality_range (tuple[float] | list[float]): JPG compression quality
            range. 0 for lowest quality, 100 for best quality.
            Default: (90, 100).

    Returns:
        (Numpy array): Returned image after JPG, shape (h, w, c),
            range [0, 1], float32.
    """
    quality = int(np.random.uniform(quality_range[0], quality_range[1]))
    return add_jpg_compression(img, quality)


def random_mixed_kernels(kernel_list,
                         kernel_prob,
                         kernel_size=21,
                         sigma_x_range=(0.6, 5),
                         sigma_y_range=(0.6, 5),
                         rotation_range=(-math.pi, math.pi),
                         betag_range=(0.5, 8),
                         betap_range=(0.5, 8),
                         noise_range=None):
    """Randomly generate mixed kernels.

    Args:
        kernel_list (tuple): a list of kernel type names; only 'iso' and
            'aniso' are implemented here.
        kernel_prob (tuple): corresponding probability for each kernel type.
        kernel_size (int): size of the (square, odd-sized) kernel.
        sigma_x_range (tuple): e.g. (0.6, 5).
        sigma_y_range (tuple): e.g. (0.6, 5).
        rotation_range (tuple): e.g. (-math.pi, math.pi).
        betag_range (tuple): e.g. (0.5, 8); unused for 'iso'/'aniso'.
        betap_range (tuple): e.g. (0.5, 8); unused for 'iso'/'aniso'.
        noise_range (tuple, optional): multiplicative kernel noise, e.g.
            (0.75, 1.25). Default: None.

    Returns:
        kernel (ndarray): sampled blur kernel.
    """
    kernel_type = random.choices(kernel_list, kernel_prob)[0]
    if kernel_type == 'iso':
        kernel = random_bivariate_Gaussian(kernel_size,
                                           sigma_x_range,
                                           sigma_y_range,
                                           rotation_range,
                                           noise_range=noise_range,
                                           isotropic=True)
    elif kernel_type == 'aniso':
        kernel = random_bivariate_Gaussian(kernel_size,
                                           sigma_x_range,
                                           sigma_y_range,
                                           rotation_range,
                                           noise_range=noise_range,
                                           isotropic=False)
    return kernel


def random_bivariate_Gaussian(kernel_size,
                              sigma_x_range,
                              sigma_y_range,
                              rotation_range,
                              noise_range=None,
                              isotropic=True):
    """Randomly generate bivariate isotropic or anisotropic Gaussian kernels.

    In the isotropic mode, only `sigma_x_range` is used; `sigma_y_range` and
    `rotation_range` are ignored.

    Args:
        kernel_size (int): size of the (square, odd-sized) kernel.
        sigma_x_range (tuple): e.g. (0.6, 5).
        sigma_y_range (tuple): e.g. (0.6, 5).
        rotation_range (tuple): e.g. (-math.pi, math.pi).
        noise_range (tuple, optional): multiplicative kernel noise, e.g.
            (0.75, 1.25). Default: None.

    Returns:
        kernel (ndarray): sampled, normalized kernel.
    """
    assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
    assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
    sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
    if isotropic is False:
        assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
        assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
        sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
        rotation = np.random.uniform(rotation_range[0], rotation_range[1])
    else:
        sigma_y = sigma_x
        rotation = 0

    kernel = bivariate_Gaussian(kernel_size,
                                sigma_x,
                                sigma_y,
                                rotation,
                                isotropic=isotropic)

    # add multiplicative noise
    if noise_range is not None:
        assert noise_range[0] < noise_range[1], 'Wrong noise range.'
        noise = np.random.uniform(noise_range[0],
                                  noise_range[1],
                                  size=kernel.shape)
        kernel = kernel * noise

    kernel = kernel / np.sum(kernel)
    return kernel


def bivariate_Gaussian(kernel_size,
                       sig_x,
                       sig_y,
                       theta,
                       grid=None,
                       isotropic=True):
    """Generate a bivariate isotropic or anisotropic Gaussian kernel.

    In the isotropic mode, only `sig_x` is used; `sig_y` and `theta` are
    ignored.

    Args:
        kernel_size (int): size of the (square) kernel.
        sig_x (float): standard deviation along the x axis.
        sig_y (float): standard deviation along the y axis.
        theta (float): rotation angle in radians.
        grid (ndarray, optional): generated by :func:`mesh_grid`,
            with the shape (K, K, 2), K is the kernel size. Default: None.
        isotropic (bool): whether to use an isotropic covariance.

    Returns:
        kernel (ndarray): normalized kernel.
    """
    if grid is None:
        grid, _, _ = mesh_grid(kernel_size)
    if isotropic:
        sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]])
    else:
        sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
    kernel = pdf2(sigma_matrix, grid)
    kernel = kernel / np.sum(kernel)
    return kernel


def sigma_matrix2(sig_x, sig_y, theta):
    """Calculate the rotated sigma matrix (two dimensional matrix).

    Args:
        sig_x (float): standard deviation along the x axis.
        sig_y (float): standard deviation along the y axis.
        theta (float): rotation angle in radians.

    Returns:
        ndarray: Rotated sigma matrix.
    """
    d_matrix = np.array([[sig_x**2, 0], [0, sig_y**2]])
    u_matrix = np.array([[np.cos(theta), -np.sin(theta)],
                         [np.sin(theta), np.cos(theta)]])
    return np.dot(u_matrix, np.dot(d_matrix, u_matrix.T))


def mesh_grid(kernel_size):
    """Generate the mesh grid, centering at zero.

    Args:
        kernel_size (int): size of the (square) kernel.

    Returns:
        xy (ndarray): with the shape (kernel_size, kernel_size, 2)
        xx (ndarray): with the shape (kernel_size, kernel_size)
        yy (ndarray): with the shape (kernel_size, kernel_size)
    """
    ax = np.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1.)
    xx, yy = np.meshgrid(ax, ax)
    xy = np.hstack((xx.reshape((kernel_size * kernel_size, 1)),
                    yy.reshape(kernel_size * kernel_size,
                               1))).reshape(kernel_size, kernel_size, 2)
    return xy, xx, yy


def pdf2(sigma_matrix, grid):
    """Calculate PDF of the bivariate Gaussian distribution.

    Args:
        sigma_matrix (ndarray): with the shape (2, 2)
        grid (ndarray): generated by :func:`mesh_grid`,
            with the shape (K, K, 2), K is the kernel size.

    Returns:
        kernel (ndarray): un-normalized kernel.
    """
    inverse_sigma = np.linalg.inv(sigma_matrix)
    kernel = np.exp(-0.5 * np.sum(np.dot(grid, inverse_sigma) * grid, 2))
    return kernel


class GFPGAN_degradation(object):
    """Random degradation pipeline (blur, downsample, noise, JPEG) from GFP-GAN."""

    def __init__(self):
        self.kernel_list = ['iso', 'aniso']
        self.kernel_prob = [0.5, 0.5]
        self.blur_kernel_size = 41
        self.blur_sigma = [0.1, 10]
        self.downsample_range = [0.8, 8]
        self.noise_range = [0, 20]
        self.jpeg_range = [60, 100]
        self.gray_prob = 0.2
        self.color_jitter_prob = 0.0
        self.color_jitter_pt_prob = 0.0
        self.shift = 20 / 255.

    def degrade_process(self, img_gt):
        if random.random() > 0.5:
            img_gt = cv2.flip(img_gt, 1)

        h, w = img_gt.shape[:2]

        # random color jitter
        if np.random.uniform() < self.color_jitter_prob:
            jitter_val = np.random.uniform(-self.shift, self.shift,
                                           3).astype(np.float32)
            img_gt = img_gt + jitter_val
            img_gt = np.clip(img_gt, 0, 1)

        # random grayscale
        if np.random.uniform() < self.gray_prob:
            img_gt = cv2.cvtColor(img_gt, cv2.COLOR_BGR2GRAY)
            img_gt = np.tile(img_gt[:, :, None], [1, 1, 3])

        # ------------------------ generate lq image ------------------------ #
        # blur
        kernel = random_mixed_kernels(self.kernel_list,
                                      self.kernel_prob,
                                      self.blur_kernel_size,
                                      self.blur_sigma,
                                      self.blur_sigma, [-math.pi, math.pi],
                                      noise_range=None)
        img_lq = cv2.filter2D(img_gt, -1, kernel)
        # downsample
        scale = np.random.uniform(self.downsample_range[0],
                                  self.downsample_range[1])
        img_lq = cv2.resize(img_lq, (int(w // scale), int(h // scale)),
                            interpolation=cv2.INTER_LINEAR)
        # noise
        if self.noise_range is not None:
            img_lq = random_add_gaussian_noise(img_lq, self.noise_range)
        # jpeg compression
        if self.jpeg_range is not None:
            img_lq = random_add_jpg_compression(img_lq, self.jpeg_range)

        # round and clip
        img_lq = np.clip((img_lq * 255.0).round(), 0, 255) / 255.

        # resize back to the original size
        img_lq = cv2.resize(img_lq, (w, h), interpolation=cv2.INTER_LINEAR)

        return img_gt, img_lq


@DATASETS.register()
class GPENDataset(Dataset):
    """FFHQ-style face dataset for GPEN: yields (degraded, ground-truth) pairs."""

    def __init__(self, dataroot, size=256, amount=-1):
        super(GPENDataset, self).__init__()
        self.size = size
        # keep only the first `amount` images
        self.HQ_imgs = sorted(glob.glob(os.path.join(dataroot,
                                                     '*/*.*g')))[:amount]
        self.length = len(self.HQ_imgs)
        if self.length == 0:
            # fall back to a flat directory layout
            self.HQ_imgs = sorted(glob.glob(os.path.join(dataroot,
                                                         '*.*g')))[:amount]
            self.length = len(self.HQ_imgs)
        logger.info('found %d images in %s', self.length, dataroot)
        self.degrader = GFPGAN_degradation()

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        """Get a training sample.

        Returns:
            img_lq: degraded image, CHW RGB float32 in [-1, 1]
            img_gt: ground-truth image, CHW RGB float32 in [-1, 1]
        """
        img_gt = cv2.imread(self.HQ_imgs[index], cv2.IMREAD_COLOR)
        img_gt = cv2.resize(img_gt, (self.size, self.size),
                            interpolation=cv2.INTER_AREA)

        # BFR degradation
        img_gt = img_gt.astype(np.float32) / 255.
        img_gt, img_lq = self.degrader.degrade_process(img_gt)

        img_gt = (paddle.to_tensor(img_gt) - 0.5) / 0.5
        img_lq = (paddle.to_tensor(img_lq) - 0.5) / 0.5
        # HWC(BGR) -> CHW(RGB)
        img_gt = img_gt.transpose([2, 0, 1]).flip(0)
        img_lq = img_lq.transpose([2, 0, 1]).flip(0)

        return np.array(img_lq).astype('float32'), np.array(img_gt).astype(
            'float32')
@@ -37,3 +37,4 @@ from .msvsr_model import MultiStageVSRModel
 from .singan_model import SinGANModel
 from .rcan_model import RCANModel
 from .prenet_model import PReNetModel
+from .gpen_model import GPENModel
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import namedtuple

import paddle
import paddle.nn as nn


class Flatten(nn.Layer):

    def forward(self, input):
        return paddle.reshape(input, [input.shape[0], -1])


def l2_norm(input, axis=1):
    norm = paddle.norm(input, 2, axis, True)
    output = paddle.divide(input, norm)
    return output


class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
    """ A named tuple describing a ResNet block. """


def get_block(in_channel, depth, num_units, stride=2):
    return [Bottleneck(in_channel, depth, stride)
            ] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)]


def get_blocks(num_layers):
    if num_layers == 50:
        blocks = [
            get_block(in_channel=64, depth=64, num_units=3),
            get_block(in_channel=64, depth=128, num_units=4),
            get_block(in_channel=128, depth=256, num_units=14),
            get_block(in_channel=256, depth=512, num_units=3)
        ]
    elif num_layers == 100:
        blocks = [
            get_block(in_channel=64, depth=64, num_units=3),
            get_block(in_channel=64, depth=128, num_units=13),
            get_block(in_channel=128, depth=256, num_units=30),
            get_block(in_channel=256, depth=512, num_units=3)
        ]
    elif num_layers == 152:
        blocks = [
            get_block(in_channel=64, depth=64, num_units=3),
            get_block(in_channel=64, depth=128, num_units=8),
            get_block(in_channel=128, depth=256, num_units=36),
            get_block(in_channel=256, depth=512, num_units=3)
        ]
    else:
        raise ValueError(
            "Invalid number of layers: {}. Must be one of [50, 100, 152]".
            format(num_layers))
    return blocks


class SEModule(nn.Layer):

    def __init__(self, channels, reduction):
        super(SEModule, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2D(1)
        self.fc1 = nn.Conv2D(channels,
                             channels // reduction,
                             kernel_size=1,
                             padding=0,
                             bias_attr=False)
        self.relu = nn.ReLU()
        self.fc2 = nn.Conv2D(channels // reduction,
                             channels,
                             kernel_size=1,
                             padding=0,
                             bias_attr=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        module_input = x
        x = self.avg_pool(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.sigmoid(x)
        return module_input * x


class bottleneck_IR(nn.Layer):

    def __init__(self, in_channel, depth, stride):
        super(bottleneck_IR, self).__init__()
        if in_channel == depth:
            self.shortcut_layer = nn.MaxPool2D(1, stride)
        else:
            self.shortcut_layer = nn.Sequential(
                nn.Conv2D(in_channel, depth, (1, 1), stride, bias_attr=False),
                nn.BatchNorm2D(depth))
        self.res_layer = nn.Sequential(
            nn.BatchNorm2D(in_channel),
            nn.Conv2D(in_channel, depth, (3, 3), (1, 1), 1, bias_attr=False),
            nn.PReLU(depth),
            nn.Conv2D(depth, depth, (3, 3), stride, 1, bias_attr=False),
            nn.BatchNorm2D(depth))

    def forward(self, x):
        shortcut = self.shortcut_layer(x)
        res = self.res_layer(x)
        return res + shortcut


class bottleneck_IR_SE(nn.Layer):

    def __init__(self, in_channel, depth, stride):
        super(bottleneck_IR_SE, self).__init__()
        if in_channel == depth:
            self.shortcut_layer = nn.MaxPool2D(1, stride)
        else:
            self.shortcut_layer = nn.Sequential(
                nn.Conv2D(in_channel, depth, (1, 1), stride, bias_attr=False),
                nn.BatchNorm2D(depth))
        self.res_layer = nn.Sequential(
            nn.BatchNorm2D(in_channel),
            nn.Conv2D(in_channel, depth, (3, 3), (1, 1), 1, bias_attr=False),
            nn.PReLU(depth),
            nn.Conv2D(depth, depth, (3, 3), stride, 1, bias_attr=False),
            nn.BatchNorm2D(depth), SEModule(depth, 16))

    def forward(self, x):
        shortcut = self.shortcut_layer(x)
        res = self.res_layer(x)
        return res + shortcut
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

import paddle
from paddle.vision.transforms import Resize

from .model_irse import Backbone
from ..builder import CRITERIONS
from ppgan.utils.download import get_path_from_url

model_cfgs = {
    'model_urls':
    'https://paddlegan.bj.bcebos.com/models/model_ir_se50.pdparams',
}


@CRITERIONS.register()
class IDLoss(paddle.nn.Layer):
    """Identity loss based on an ArcFace (IR-SE50) face recognition backbone."""

    def __init__(self, base_dir='./'):
        super(IDLoss, self).__init__()
        print('Loading ResNet ArcFace')
        self.facenet = Backbone(input_size=112,
                                num_layers=50,
                                drop_ratio=0.6,
                                mode='ir_se')

        facenet_weights_path = os.path.join(base_dir, 'data/gpen/weights',
                                            'model_ir_se50.pdparams')

        if not os.path.isfile(facenet_weights_path):
            facenet_weights_path = get_path_from_url(model_cfgs['model_urls'])

        self.facenet.load_dict(paddle.load(facenet_weights_path))

        self.face_pool = paddle.nn.AdaptiveAvgPool2D((112, 112))
        self.facenet.eval()

    def extract_feats(self, x):
        _, _, h, w = x.shape
        assert h == w
        ss = h // 256
        # crop the central face region before feeding the recognition network
        x = x[:, :, 35 * ss:-33 * ss, 32 * ss:-36 * ss]

        transform = Resize(size=(112, 112))

        for num in range(x.shape[0]):
            mid_feats = transform(x[num]).unsqueeze(0)
            if num == 0:
                x_feats = mid_feats
            else:
                x_feats = paddle.concat([x_feats, mid_feats], axis=0)

        x_feats = self.facenet(x_feats)
        return x_feats

    def forward(self, y_hat, y, x):
        n_samples = x.shape[0]
        y_feats = self.extract_feats(y)
        y_hat_feats = self.extract_feats(y_hat)
        y_feats = y_feats.detach()
        loss = 0
        count = 0
        for i in range(n_samples):
            # similarity between L2-normalized embeddings
            diff_target = y_hat_feats[i].dot(y_feats[i])
            loss += 1 - diff_target
            count += 1

        return loss / count
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn

from .helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm
"""
Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
"""


class Backbone(nn.Layer):

    def __init__(self,
                 input_size,
                 num_layers,
                 mode='ir',
                 drop_ratio=0.4,
                 affine=True):
        super(Backbone, self).__init__()
        assert input_size in [112, 224], "input_size should be 112 or 224"
        assert num_layers in [50, 100,
                              152], "num_layers should be 50, 100 or 152"
        assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se"
        blocks = get_blocks(num_layers)
        if mode == 'ir':
            unit_module = bottleneck_IR
        elif mode == 'ir_se':
            unit_module = bottleneck_IR_SE
        self.input_layer = paddle.nn.Sequential(
            nn.Conv2D(3, 64, (3, 3), 1, 1, bias_attr=False), nn.BatchNorm2D(64),
            nn.PReLU(64))
        if input_size == 112:
            self.output_layer = nn.Sequential(nn.BatchNorm2D(512),
                                              nn.Dropout(drop_ratio), Flatten(),
                                              nn.Linear(512 * 7 * 7, 512),
                                              nn.BatchNorm1D(512))
        else:
            self.output_layer = nn.Sequential(nn.BatchNorm2D(512),
                                              nn.Dropout(drop_ratio), Flatten(),
                                              nn.Linear(512 * 14 * 14, 512),
                                              nn.BatchNorm1D(512))

        modules = []
        for block in blocks:
            for bottleneck in block:
                modules.append(
                    unit_module(bottleneck.in_channel, bottleneck.depth,
                                bottleneck.stride))
        self.body = nn.Sequential(*modules)

    def forward(self, x):
        x = self.input_layer(x)
        x = self.body(x)
        x = self.output_layer(x)
        return l2_norm(x)
@@ -9,3 +9,4 @@ from .gradient_penalty import GradientPenalty
 from .builder import build_criterion
 from .ssim import SSIM
+from .IDLoss.id_loss import IDLoss
@@ -17,7 +17,7 @@ from .nlayers import NLayerDiscriminator, NLayerDiscriminatorWithClassification
 from .discriminator_ugatit import UGATITDiscriminator
 from .dcdiscriminator import DCDiscriminator
 from .discriminator_animegan import AnimeDiscriminator
-from .discriminator_styleganv2 import StyleGANv2Discriminator
+from .discriminator_styleganv2 import StyleGANv2Discriminator, GPENDiscriminator
 from .syncnet import SyncNetColor
 from .wav2lip_disc_qual import Wav2LipDiscQual
 from .discriminator_starganv2 import StarGANv2Discriminator
@@ -28,6 +28,7 @@ from ...modules.upfirdn2d import Upfirdn2dBlur
 class ConvLayer(nn.Sequential):
+
     def __init__(
         self,
         in_channel,
@@ -72,6 +73,7 @@ class ConvLayer(nn.Sequential):
 class ResBlock(nn.Layer):
+
     def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
         super().__init__()
@@ -112,6 +114,7 @@ def var(x, axis=None, unbiased=True, keepdim=False, name=None):
 @DISCRIMINATORS.register()
 class StyleGANv2Discriminator(nn.Layer):
+
     def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
         super().__init__()
@@ -171,3 +174,71 @@ class StyleGANv2Discriminator(nn.Layer):
         out = self.final_linear(out)
 
         return out
+
+
+@DISCRIMINATORS.register()
+class GPENDiscriminator(nn.Layer):
+
+    def __init__(self,
+                 size,
+                 channel_multiplier=1,
+                 narrow=0.5,
+                 blur_kernel=[1, 3, 3, 1]):
+        super().__init__()
+
+        channels = {
+            4: int(512 * narrow),
+            8: int(512 * narrow),
+            16: int(512 * narrow),
+            32: int(512 * narrow),
+            64: int(256 * channel_multiplier * narrow),
+            128: int(128 * channel_multiplier * narrow),
+            256: int(64 * channel_multiplier * narrow),
+            512: int(32 * channel_multiplier * narrow),
+            1024: int(16 * channel_multiplier * narrow),
+        }
+
+        convs = [ConvLayer(3, channels[size], 1)]
+
+        log_size = int(math.log(size, 2))
+
+        in_channel = channels[size]
+
+        for i in range(log_size, 2, -1):
+            out_channel = channels[2**(i - 1)]
+
+            convs.append(ResBlock(in_channel, out_channel, blur_kernel))
+
+            in_channel = out_channel
+
+        self.convs = nn.Sequential(*convs)
+
+        self.stddev_group = 4
+        self.stddev_feat = 1
+
+        self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
+        self.final_linear = nn.Sequential(
+            EqualLinear(channels[4] * 4 * 4,
+                        channels[4],
+                        activation="fused_lrelu"),
+            EqualLinear(channels[4], 1),
+        )
+
+    def forward(self, input):
+        out = self.convs(input)
+
+        batch, channel, height, width = out.shape
+        group = min(batch, self.stddev_group)
+        stddev = out.reshape((group, -1, self.stddev_feat,
+                              channel // self.stddev_feat, height, width))
+        stddev = paddle.sqrt(var(stddev, 0, unbiased=False) + 1e-8)
+        stddev = stddev.mean([2, 3, 4], keepdim=True).squeeze(2)
+        stddev = stddev.tile((group, 1, height, width))
+        out = paddle.concat([out, stddev], 1)
+
+        out = self.final_conv(out)
+        out = out.reshape((batch, -1))
+        out = self.final_linear(out)
+
+        return out
@@ -41,3 +41,4 @@ from .msvsr import MSVSR
 from .generator_singan import SinGANGenerator
 from .rcan import RCAN
 from .prenet import PReNet
+from .gpen import GPEN
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# code was heavily based on https://github.com/rosinality/stylegan2-pytorch
# MIT License
# Copyright (c) 2019 Kim Seonghyeon
import math
import random

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from ppgan.modules.equalized import EqualLinear_gpen as EqualLinear
from ppgan.modules.fused_act import FusedLeakyReLU
from ppgan.modules.upfirdn2d import Upfirdn2dUpsample, Upfirdn2dBlur


class PixelNorm(nn.Layer):

    def __init__(self):
        super().__init__()

    def forward(self, inputs):
        return inputs * paddle.rsqrt(
            paddle.mean(inputs * inputs, 1, keepdim=True) + 1e-8)


class ModulatedConv2D(nn.Layer):

    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        style_dim,
        demodulate=True,
        upsample=False,
        downsample=False,
        blur_kernel=[1, 3, 3, 1],
    ):
        super().__init__()

        self.eps = 1e-8
        self.kernel_size = kernel_size
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.upsample = upsample
        self.downsample = downsample

        if upsample:
            factor = 2
            p = (len(blur_kernel) - factor) - (kernel_size - 1)
            pad0 = (p + 1) // 2 + factor - 1
            pad1 = p // 2 + 1
            self.blur = Upfirdn2dBlur(blur_kernel,
                                      pad=(pad0, pad1),
                                      upsample_factor=factor)

        if downsample:
            factor = 2
            p = (len(blur_kernel) - factor) + (kernel_size - 1)
            pad0 = (p + 1) // 2
            pad1 = p // 2
            self.blur = Upfirdn2dBlur(blur_kernel, pad=(pad0, pad1))

        fan_in = in_channel * (kernel_size * kernel_size)
        self.scale = 1 / math.sqrt(fan_in)
        self.padding = kernel_size // 2
        self.weight = self.create_parameter(
            (1, out_channel, in_channel, kernel_size, kernel_size),
            default_initializer=nn.initializer.Normal())
        self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
        self.demodulate = demodulate

    def __repr__(self):
        return (
            f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, "
            f"upsample={self.upsample}, downsample={self.downsample})")

    def forward(self, inputs, style):
        batch, in_channel, height, width = inputs.shape

        # modulate the conv weight per sample with the style vector
        style = self.modulation(style).reshape((batch, 1, in_channel, 1, 1))
        weight = self.scale * self.weight * style

        if self.demodulate:
            demod = paddle.rsqrt((weight * weight).sum([2, 3, 4]) + 1e-8)
            weight = weight * demod.reshape((batch, self.out_channel, 1, 1, 1))

        weight = weight.reshape((batch * self.out_channel, in_channel,
                                 self.kernel_size, self.kernel_size))

        if self.upsample:
            inputs = inputs.reshape((1, batch * in_channel, height, width))
            weight = weight.reshape((batch, self.out_channel, in_channel,
                                     self.kernel_size, self.kernel_size))
            weight = weight.transpose((0, 2, 1, 3, 4)).reshape(
                (batch * in_channel, self.out_channel, self.kernel_size,
                 self.kernel_size))
            out = F.conv2d_transpose(inputs,
                                     weight,
                                     padding=0,
                                     stride=2,
                                     groups=batch)
            _, _, height, width = out.shape
            out = out.reshape((batch, self.out_channel, height, width))
            out = self.blur(out)

        elif self.downsample:
            inputs = self.blur(inputs)
            _, _, height, width = inputs.shape
            inputs = inputs.reshape((1, batch * in_channel, height, width))
            out = F.conv2d(inputs, weight, padding=0, stride=2, groups=batch)
            _, _, height, width = out.shape
            out = out.reshape((batch, self.out_channel, height, width))

        else:
            inputs = inputs.reshape((1, batch * in_channel, height, width))
            out = F.conv2d(inputs, weight, padding=self.padding, groups=batch)
            _, _, height, width = out.shape
            out = out.reshape((batch, self.out_channel, height, width))

        return out


class NoiseInjection(nn.Layer):

    def __init__(self, is_concat=False):
        super().__init__()
        self.weight = self.create_parameter(
            (1, ), default_initializer=nn.initializer.Constant(0.0))
        self.is_concat = is_concat

    def forward(self, image, noise=None):
        if noise is None:
            batch, _, height, width = image.shape
            noise = paddle.randn((batch, 1, height, width))
        if self.is_concat:
            # GPEN concatenates the noise channel instead of adding it
            return paddle.concat([image, self.weight * noise], axis=1)
        else:
            return image + self.weight * noise


class ConstantInput(nn.Layer):

    def __init__(self, channel, size=4):
        super().__init__()
        self.input = self.create_parameter(
            (1, channel, size, size),
            default_initializer=nn.initializer.Normal())

    def forward(self, inputs):
        batch = inputs.shape[0]
        out = self.input.tile((batch, 1, 1, 1))
        return out


class StyledConv(nn.Layer):

    def __init__(self,
                 in_channel,
                 out_channel,
                 kernel_size,
                 style_dim,
                 upsample=False,
                 blur_kernel=[1, 3, 3, 1],
                 demodulate=True,
                 is_concat=False):
        super().__init__()
        self.conv = ModulatedConv2D(
            in_channel,
            out_channel,
            kernel_size,
            style_dim,
            upsample=upsample,
            blur_kernel=blur_kernel,
            demodulate=demodulate,
        )
        self.noise = NoiseInjection(is_concat=is_concat)
        self.activate = FusedLeakyReLU(out_channel *
                                       2 if is_concat else out_channel)

    def forward(self, inputs, style, noise=None):
        out = self.conv(inputs, style)
        out = self.noise(out, noise=noise)
        out = self.activate(out)
        return out


class ToRGB(nn.Layer):

    def __init__(self,
                 in_channel,
                 style_dim,
                 upsample=True,
                 blur_kernel=[1, 3, 3, 1]):
        super().__init__()

        if upsample:
            self.upsample = Upfirdn2dUpsample(blur_kernel)

        self.conv = ModulatedConv2D(in_channel,
                                    3,
                                    1,
                                    style_dim,
                                    demodulate=False)
        self.bias = self.create_parameter((1, 3, 1, 1),
                                          nn.initializer.Constant(0.0))

    def forward(self, inputs, style, skip=None):
        out = self.conv(inputs, style)
        out = out + self.bias

        if skip is not None:
            skip = self.upsample(skip)
            out = out + skip

        return out


class StyleGANv2Generator(nn.Layer):

    def __init__(self,
                 size,
                 style_dim,
                 n_mlp,
                 channel_multiplier=1,
                 narrow=0.5,
                 blur_kernel=[1, 3, 3, 1],
                 lr_mlp=0.01,
                 is_concat=True):
        super().__init__()

        self.size = size
        self.style_dim = style_dim
        # stored so forward() can build default noise maps when none are given
        self.n_mlp = n_mlp

        layers = [PixelNorm()]
        for i in range(n_mlp):
            layers.append(
                EqualLinear(style_dim,
                            style_dim,
                            lr_mul=lr_mlp,
                            activation="fused_lrelu"))

        self.style = nn.Sequential(*layers)

        self.channels = {
            4: int(512 * narrow),
            8: int(512 * narrow),
            16: int(512 * narrow),
            32: int(512 * narrow),
            64: int(256 * channel_multiplier * narrow),
            128: int(128 * channel_multiplier * narrow),
            256: int(64 * channel_multiplier * narrow),
            512: int(32 * channel_multiplier * narrow),
            1024: int(16 * channel_multiplier * narrow),
            2048: int(8 * channel_multiplier * narrow)
        }

        self.input = ConstantInput(self.channels[4])
        self.conv1 = StyledConv(self.channels[4],
                                self.channels[4],
                                3,
                                style_dim,
                                blur_kernel=blur_kernel,
                                is_concat=is_concat)
        self.to_rgb1 = ToRGB(self.channels[4] *
                             2 if is_concat else self.channels[4],
                             style_dim,
                             upsample=False)

        self.log_size = int(math.log(size, 2))
        self.num_layers = (self.log_size - 2) * 2 + 1

        self.convs = nn.LayerList()
        self.upsamples = nn.LayerList()
        self.to_rgbs = nn.LayerList()
        self.noises = nn.Layer()

        in_channel = self.channels[4]

        for layer_idx in range(self.num_layers):
            res = (layer_idx + 5) // 2
            shape = [1, 1, 2**res, 2**res]
            self.noises.register_buffer(f"noise_{layer_idx}",
                                        paddle.randn(shape))

        for i in range(3, self.log_size + 1):
            out_channel = self.channels[2**i]

            self.convs.append(
                StyledConv(
                    in_channel * 2 if is_concat else in_channel,
                    out_channel,
                    3,
                    style_dim,
                    upsample=True,
                    blur_kernel=blur_kernel,
                    is_concat=is_concat,
                ))
            self.convs.append(
                StyledConv(out_channel * 2 if is_concat else out_channel,
                           out_channel,
                           3,
                           style_dim,
                           blur_kernel=blur_kernel,
                           is_concat=is_concat))
            self.to_rgbs.append(
                ToRGB(out_channel * 2 if is_concat else out_channel, style_dim))

            in_channel = out_channel

        self.n_latent = self.log_size * 2 - 2
        self.is_concat = is_concat

    def make_noise(self):
        noises = [paddle.randn((1, 1, 2**2, 2**2))]

        for i in range(3, self.log_size + 1):
            for _ in range(2):
                noises.append(paddle.randn((1, 1, 2**i, 2**i)))

        return noises

    def mean_latent(self, n_latent):
        latent_in = paddle.randn((n_latent, self.style_dim))
        latent = self.style(latent_in).mean(0, keepdim=True)

        return latent

    def get_latent(self, inputs):
        return self.style(inputs)

    def get_mean_style(self):
        mean_style = None
        with paddle.no_grad():
            for i in range(10):
                style = self.mean_latent(1024)
                if mean_style is None:
                    mean_style = style
                else:
                    mean_style += style

        mean_style /= 10
        return mean_style

    def forward(
        self,
        styles,
        return_latents=False,
        inject_index=None,
        truncation=1,
        truncation_latent=None,
        input_is_latent=False,
        noise=None,
    ):
        if not input_is_latent:
            styles = [self.style(s) for s in styles]

        if noise is None:
            # noise = [None] * (2 * (self.log_size - 2) + 1)
            noise = []
            batch = styles[0].shape[0]
            for i in range(self.n_mlp + 1):
                size = 2**(i + 2)
                noise.append(
                    paddle.create_parameter(
                        [batch, self.channels[size], size, size],
                        dtype='float32',
                        attr=paddle.ParamAttr(
                            initializer=nn.initializer.Constant(0),
                            trainable=True)))

        if truncation < 1:
            style_t = []

            for style in styles:
                style_t.append(truncation_latent + truncation *
                               (style - truncation_latent))

            styles = style_t

        if len(styles) < 2:
            inject_index = self.n_latent
            latent = styles[0].unsqueeze(1)
            latent = paddle.tile(latent, repeat_times=[1, inject_index, 1])
        else:
            if inject_index is None:
                inject_index = random.randint(1, self.n_latent - 1)

            latent = paddle.tile(styles[0].unsqueeze(1),
                                 repeat_times=[1, inject_index, 1])
            latent2 = paddle.tile(
                styles[1].unsqueeze(1),
                repeat_times=[1, self.n_latent - inject_index, 1])

            latent = paddle.concat([latent, latent2], 1)

        out = self.input(latent)
        out = self.conv1(out, latent[:, 0], noise=noise[0])

        skip = self.to_rgb1(out, latent[:, 1])
        i = 1

        for conv1, conv2, noise1, noise2, to_rgb in zip(self.convs[::2],
                                                        self.convs[1::2],
                                                        noise[1::2],
                                                        noise[2::2],
                                                        self.to_rgbs):
            out = conv1(out, latent[:, i], noise=noise1)
            out = conv2(out, latent[:, i + 1], noise=noise2)
            skip = to_rgb(out, latent[:, i + 2], skip)

            i += 2

        image = skip

        if return_latents:
            return image, latent
        else:
            return image, None
@@ -12,60 +12,73 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# code was heavily based on https://github.com/yangxy/GPEN

import itertools
import paddle.nn as nn
import math

from ppgan.models.generators.builder import GENERATORS
from ppgan.modules.equalized import EqualLinear_gpen as EqualLinear
from ppgan.models.generators.generator_gpen import StyleGANv2Generator
from ppgan.models.discriminators.discriminator_styleganv2 import ConvLayer


@GENERATORS.register()
class GPEN(nn.Layer):
    def __init__(
        self,
        size,
        style_dim,
        n_mlp,
        channel_multiplier=2,
        narrow=1,
        blur_kernel=[1, 3, 3, 1],
        lr_mlp=0.01,
        is_concat=True,
    ):
        super(GPEN, self).__init__()
        channels = {
            4: int(512 * narrow),
            8: int(512 * narrow),
            16: int(512 * narrow),
            32: int(512 * narrow),
            64: int(256 * channel_multiplier * narrow),
            128: int(128 * channel_multiplier * narrow),
            256: int(64 * channel_multiplier * narrow),
            512: int(32 * channel_multiplier * narrow),
            1024: int(16 * channel_multiplier * narrow),
            2048: int(8 * channel_multiplier * narrow)
        }

        self.log_size = int(math.log(size, 2))
        self.generator = StyleGANv2Generator(
            size,
            style_dim,
            n_mlp,
            channel_multiplier=channel_multiplier,
            narrow=narrow,
            blur_kernel=blur_kernel,
            lr_mlp=lr_mlp,
            is_concat=is_concat)

        conv = [ConvLayer(3, channels[size], 1)]
        self.ecd0 = nn.Sequential(*conv)
        in_channel = channels[size]

        self.names = ['ecd%d' % i for i in range(self.log_size - 1)]
        for i in range(self.log_size, 2, -1):
            out_channel = channels[2**(i - 1)]
            conv = [ConvLayer(in_channel, out_channel, 3, downsample=True)]
            setattr(self, self.names[self.log_size - i + 1],
                    nn.Sequential(*conv))
            in_channel = out_channel

        self.final_linear = nn.Sequential(
            EqualLinear(channels[4] * 4 * 4,
                        style_dim,
                        activation='fused_lrelu'))

    def forward(
        self,
        inputs,
        return_latents=False,
        inject_index=None,
@@ -74,15 +87,20 @@ class GPEN(nn.Layer):
        input_is_latent=False,
    ):
        noise = []
        for i in range(self.log_size - 1):
            ecd = getattr(self, self.names[i])
            inputs = ecd(inputs)
            noise.append(inputs)
        inputs = inputs.reshape([inputs.shape[0], -1])
        outs = self.final_linear(inputs)
        noise = list(
            itertools.chain.from_iterable(
                itertools.repeat(x, 2) for x in noise))[::-1]
        outs = self.generator([outs],
                              return_latents,
                              inject_index,
                              truncation,
                              truncation_latent,
                              input_is_latent,
                              noise=noise[1:])
        return outs
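
To make the encoder-to-decoder hand-off concrete, here is a minimal usage sketch (not part of the commit; the import path and the [-1, 1] input range are assumptions based on the registration and config above). The encoder features double as per-layer noise, and the flattened deepest feature becomes the style vector:

import paddle
from ppgan.models.generators.gpen import GPEN  # assumed module path

net = GPEN(size=256, style_dim=512, n_mlp=8, channel_multiplier=1, narrow=0.5)
net.eval()
degraded = paddle.randn([1, 3, 256, 256])  # stand-in for a degraded face batch
restored, _ = net(degraded)                # forward returns (image, latents or None)
print(restored.shape)                      # [1, 3, 256, 256]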
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from .base_model import BaseModel
from .builder import MODELS
from .generators.builder import build_generator
from .discriminators.builder import build_discriminator
from ..modules.init import init_weights
from .criterions.IDLoss.id_loss import IDLoss
from paddle.nn import functional as F
from paddle import autograd
import math
def d_logistic_loss(real_pred, fake_pred):
real_loss = F.softplus(-real_pred)
fake_loss = F.softplus(fake_pred)
return real_loss.mean() + fake_loss.mean()
def d_r1_loss(real_pred, real_img):
grad_real, = autograd.grad(outputs=real_pred.sum(),
inputs=real_img,
create_graph=True)
grad_penalty = grad_real.pow(2).reshape([grad_real.shape[0],
-1]).sum(1).mean()
return grad_penalty
def g_nonsaturating_loss(fake_pred,
loss_funcs=None,
fake_img=None,
real_img=None,
input_img=None):
smooth_l1_loss, id_loss = loss_funcs
loss = F.softplus(-fake_pred).mean()
loss_l1 = smooth_l1_loss(fake_img, real_img)
loss_id = id_loss(fake_img, real_img, input_img)
loss += 1.0 * loss_l1 + 1.0 * loss_id
return loss
def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
noise = paddle.randn(fake_img.shape) / math.sqrt(
fake_img.shape[2] * fake_img.shape[3])
grad, = autograd.grad(outputs=(fake_img * noise).sum(),
inputs=latents,
create_graph=True)
path_lengths = paddle.sqrt(grad.pow(2).sum(2).mean(1))
path_mean = mean_path_length + decay * (path_lengths.mean() -
mean_path_length)
path_penalty = (path_lengths - path_mean).pow(2).mean()
return path_penalty, path_mean.detach(), path_lengths
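
d_r1_loss above implements the standard R1 gradient penalty, the mean squared gradient norm of the discriminator on real images. A toy sanity check of that arithmetic (illustrative only, using a closed-form stand-in for the discriminator):

import paddle
from paddle import autograd

# For D(x) = sum(x ** 2) the gradient is 2x, so on x = ones([1, 4])
# the penalty is sum((2 * 1) ** 2) = 16.
x = paddle.ones([1, 4])
x.stop_gradient = False
real_pred = (x ** 2).sum()
grad_real, = autograd.grad(outputs=real_pred, inputs=x, create_graph=True)
penalty = grad_real.pow(2).reshape([grad_real.shape[0], -1]).sum(1).mean()
print(penalty.item())  # 16.0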
@MODELS.register()
class GPENModel(BaseModel):
""" This class implements the gpen model.
"""
def __init__(self, generator, discriminator=None, direction='a2b'):
super(GPENModel, self).__init__()
self.direction = direction
# define networks (both generator and discriminator)
self.nets['netG'] = build_generator(generator)
self.nets['g_ema'] = build_generator(generator)
self.nets['g_ema'].eval()
if discriminator:
self.nets['netD'] = build_discriminator(discriminator)
self.accum = 0.5**(32 / (10 * 1000))
self.mean_path_length = 0
self.gan_criterions = []
self.gan_criterions.append(paddle.nn.SmoothL1Loss())
self.gan_criterions.append(IDLoss())
self.current_iter = 0
def setup_input(self, input):
self.degraded_img = paddle.to_tensor(input[0])
self.real_img = paddle.to_tensor(input[1])
def forward(self, test_mode=False, regularize=False):
if test_mode:
self.fake_img, _ = self.nets['g_ema'](self.degraded_img) # G(A)
else:
if regularize:
self.fake_img, self.latents = self.nets['netG'](
self.degraded_img, return_latents=True)
else:
self.fake_img, _ = self.nets['netG'](self.degraded_img)
def backward_D(self, regularize=False):
"""Calculate GAN loss for the discriminator"""
if regularize:
self.real_img.stop_gradient = False
real_pred = self.nets['netD'](self.real_img)
r1_loss = d_r1_loss(real_pred, self.real_img)
(10 / 2 * r1_loss * 16).backward()
else:
fake_pred = self.nets['netD'](self.fake_img)
real_pred = self.nets['netD'](self.real_img)
self.loss_D = d_logistic_loss(real_pred, fake_pred)
self.loss_D.backward()
self.losses['D_loss'] = self.loss_D
def backward_G(self, regularize):
"""Calculate GAN and L1 loss for the generator"""
if regularize:
path_loss, self.mean_path_length, path_lengths = g_path_regularize(
self.fake_img, self.latents, self.mean_path_length)
weighted_path_loss = 2 * 4 * path_loss
weighted_path_loss.backward()
else:
fake_pred = self.nets['netD'](self.fake_img)
self.loss_G = g_nonsaturating_loss(fake_pred, self.gan_criterions,
self.fake_img, self.real_img,
self.degraded_img)
self.loss_G.backward()
self.losses['G_loss'] = self.loss_G
def train_iter(self, optimizers=None):
self.current_iter += 1
# update D
self.set_requires_grad(self.nets['netD'], True)
self.set_requires_grad(self.nets['netG'], False)
self.forward(test_mode=False)
optimizers['optimD'].clear_grad()
self.backward_D(regularize=False)
optimizers['optimD'].step()
d_regularize = self.current_iter % 24 == 0
if d_regularize:
optimizers['optimD'].clear_grad()
self.backward_D(regularize=True)
optimizers['optimD'].step()
# update G
self.set_requires_grad(self.nets['netD'], False)
self.set_requires_grad(self.nets['netG'], True)
self.forward(test_mode=False)
optimizers['optimG'].clear_grad()
self.backward_G(regularize=False)
optimizers['optimG'].step()
g_regularize = self.current_iter % 4 == 0
if g_regularize:
self.forward(test_mode=False, regularize=True)
optimizers['optimG'].clear_grad()
self.backward_G(regularize=True)
optimizers['optimG'].step()
self.accumulate(self.nets['g_ema'], self.nets['netG'], self.accum)
def test_iter(self, metrics=None):
self.nets['g_ema'].eval()
self.forward(test_mode=True)
with paddle.no_grad():
if metrics is not None:
for metric in metrics.values():
metric.update(self.fake_img, self.real_img)
def accumulate(self, model1, model2, decay=0.999):
par1 = dict(model1.state_dict())
par2 = dict(model2.state_dict())
for k in par1.keys():
par1[k] = par1[k] * decay + par2[k] * (1 - decay)
model1.load_dict(par1)
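
One constant worth unpacking: self.accum = 0.5**(32 / (10 * 1000)) is the usual StyleGAN2 EMA decay, picked so that g_ema's memory of past weights has a half-life of 10k images at a nominal batch size of 32. A quick check:

decay = 0.5 ** (32 / (10 * 1000))
print(decay)  # ~0.997784: g_ema keeps ~99.78% of its old weights per step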
@@ -28,6 +28,7 @@ class EqualConv2D(nn.Layer):
    """This convolutional layer class stabilizes the learning rate changes of its parameters.
    Equalizing learning rate keeps the weights in the network at a similar scale during training.
    """

    def __init__(self,
                 in_channel,
                 out_channel,
@@ -74,6 +75,7 @@ class EqualLinear(nn.Layer):
    """This linear layer class stabilizes the learning rate changes of its parameters.
    Equalizing learning rate keeps the weights in the network at a similar scale during training.
    """

    def __init__(self,
                 in_dim,
                 out_dim,
@@ -115,3 +117,50 @@ class EqualLinear(nn.Layer):
        return (
            f"{self.__class__.__name__}({self.weight.shape[0]}, {self.weight.shape[1]})"
        )
class EqualLinear_gpen(nn.Layer):
"""This linear layer class stabilizes the learning rate changes of its parameters.
Equalizing learning rate keeps the weights in the network at a similar scale during training.
"""
def __init__(self,
in_dim,
out_dim,
bias=True,
bias_init=0,
lr_mul=1,
activation=None):
super().__init__()
self.weight = self.create_parameter(
(out_dim, in_dim), default_initializer=nn.initializer.Normal())
self.weight.set_value((self.weight / lr_mul))
if bias:
self.bias = self.create_parameter(
(out_dim, ), nn.initializer.Constant(bias_init))
else:
self.bias = None
self.activation = activation
self.scale = (1 / math.sqrt(in_dim)) * lr_mul
self.lr_mul = lr_mul
def forward(self, input):
if self.activation:
out = F.linear(input, (self.weight * self.scale).t())
out = fused_leaky_relu(out, self.bias * self.lr_mul)
else:
out = F.linear(input, (self.weight * self.scale).t(),
bias=self.bias * self.lr_mul)
return out
def __repr__(self):
return (
f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})'
)
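
The scale applied at every forward pass is what makes the layer "equalized": parameters are stored near unit variance while the effective weight is the stored weight times the He constant and lr_mul. For example, with in_dim=512 and the lr_mul=0.01 typical of StyleGAN2's style MLP (values chosen for illustration):

import math

in_dim, lr_mul = 512, 0.01
scale = (1 / math.sqrt(in_dim)) * lr_mul
print(scale)  # ~0.000442: the effective weight magnitude seen at runtime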
===========================train_params===========================
model_name:GPEN
python:python3.7
gpu_list:0
##
auto_cast:null
total_iters:lite_train_lite_infer=10
output_dir:./output/
snapshot_config.interval:lite_train_lite_infer=10
pretrained_model:null
train_model_name:gpen*/*checkpoint.pdparams
train_infer_img_dir:null
null:null
##
trainer:norm_train
norm_train:tools/main.py -c configs/gpen_256_ffhq.yaml --seed 100 -o log_config.interval=1
pact_train:null
fpgm_train:null
distill_train:null
null:null
null:null
##
===========================eval_params===========================
eval:null
null:null
##
===========================infer_params===========================
--output_dir:./output/
load:null
norm_export:tools/export_model.py -c configs/gpen_256_ffhq.yaml --inputs_size=1,3,256,256 --model_name inference --load
quant_export:null
fpgm_export:null
distill_export:null
export1:null
export2:null
inference_dir:inference
train_model:./inference/gpen/gpenmodel_g_ema
infer_export:null
infer_quant:False
inference:tools/inference.py --model_type GPEN --seed 100 -c configs/gpen_256_ffhq.yaml --output_path test_tipc/output/ -o dataset.test.amount=5
--device:gpu
null:null
null:null
null:null
null:null
null:null
--model_path:
null:null
null:null
--benchmark:True
null:null
@@ -54,6 +54,10 @@ if [ ${MODE} = "lite_train_lite_infer" ];then
        rm -rf ./data/ffhq*
        wget -nc -P ./data/ https://paddlegan.bj.bcebos.com/datasets/ffhq.tar --no-check-certificate
        cd ./data/ && tar xf ffhq.tar && cd ../ ;;
    GPEN)
        rm -rf ./data/ffhq*
        wget -nc -P ./data/ https://paddlegan.bj.bcebos.com/datasets/ffhq.tar --no-check-certificate
        cd ./data/ && tar xf ffhq.tar && cd ../ ;;
    FOMM)
        rm -rf ./data/fom_lite*
        wget -nc -P ./data/ https://paddlegan.bj.bcebos.com/datasets/fom_lite.tar --no-check-certificate
@@ -106,6 +110,10 @@ elif [ ${MODE} = "lite_train_whole_infer" ];then
        rm -rf ./data/ffhq*
        wget -nc -P ./data/ https://paddlegan.bj.bcebos.com/datasets/ffhq.tar --no-check-certificate
        cd ./data/ && tar xf ffhq.tar && cd ../
    elif [ ${model_name} == "GPEN" ]; then
        rm -rf ./data/ffhq*
        wget -nc -P ./data/ https://paddlegan.bj.bcebos.com/datasets/ffhq.tar --no-check-certificate
        cd ./data/ && tar xf ffhq.tar && cd ../
    elif [ ${model_name} == "basicvsr" ]; then
        rm -rf ./data/reds*
        wget -nc -P ./data/ https://paddlegan.bj.bcebos.com/datasets/reds_lite.tar --no-check-certificate
...
@@ -16,6 +16,8 @@ import os
import sys
import argparse

sys.path.append(".")

import ppgan
from ppgan.utils.config import get_config
from ppgan.utils.setup import setup
@@ -76,7 +78,7 @@ def main(args, cfg):
    for net_name, net in model.nets.items():
        if net_name in state_dicts:
            net.set_state_dict(state_dicts[net_name])

    model.export_model(cfg.export_model, args.output_dir, inputs_size,
                       args.export_serving_model, args.model_name)
...
@@ -4,6 +4,10 @@ import numpy as np
import random
import os
from collections import OrderedDict
import sys
import cv2

sys.path.append(".")

from ppgan.utils.config import get_config
from ppgan.datasets.builder import build_dataloader
@@ -15,7 +19,7 @@ from ppgan.metrics import build_metric

MODEL_CLASSES = ["pix2pix", "cyclegan", "wav2lip", "esrgan", \
                 "edvr", "fom", "stylegan2", "basicvsr", "msvsr", "singan", "prenet", "GPEN"]


def parse_args():
@@ -313,7 +317,7 @@ def main():
            metric_file = os.path.join(args.output_path, "singan/metric.txt")
            for metric in metrics.values():
                metric.update(prediction, data['A'])
        elif model_type == "prenet":
            lq = data['lq'].numpy()
            gt = data['gt'].numpy()
            input_handles[0].copy_from_cpu(lq)
@@ -329,23 +333,31 @@ def main():
            metric_file = os.path.join(args.output_path, "prenet/metric.txt")
            for metric in metrics.values():
                metric.update(image_numpy, gt_img)
        elif model_type == "GPEN":
            lq = data[0].numpy()
            input_handles[0].copy_from_cpu(lq)
            predictor.run()
            prediction = output_handle.copy_to_cpu()
            target = data[1].numpy()

            metric_file = os.path.join(args.output_path, model_type,
                                       "metric.txt")
            for metric in metrics.values():
                metric.update(prediction, target)

            lq = paddle.to_tensor(lq)
            target = paddle.to_tensor(target)
            prediction = paddle.to_tensor(prediction)
            lq = lq.transpose([0, 2, 3, 1])
            target = target.transpose([0, 2, 3, 1])
            prediction = prediction.transpose([0, 2, 3, 1])
            sample_result = paddle.concat((lq[0], prediction[0], target[0]), 1)
            sample = cv2.cvtColor((sample_result.numpy() + 1) / 2 * 255,
                                  cv2.COLOR_RGB2BGR)
            file_name = os.path.join(args.output_path, model_type,
                                     "{}.png".format(i))
            cv2.imwrite(file_name, sample)

    if metrics:
        log_file = open(metric_file, 'a')
...