Unverified commit 7e9fd3fc, authored by simonsLiang, committed by GitHub

Add PReNet to PaddleGan (#617)

* Create prenet.md

* Update prenet.md

* Update prenet.md

* Update prenet.md

* Update prenet.md

* Update prenet.md

* Update prenet.md

* Update prenet.md

* Update prenet.md

* Update README.md

* upload

* Delete test_tipc/configs/PReNet directory

* Update prenet_model.py

* Update ssim.py

* Update ssim.py

* Update prenet.yaml
Parent 5df8fc28
@@ -34,7 +34,7 @@ GAN-Generative Adversarial Network, was praised by "the Father of Convolutional
 <div align='center'>
 <img src='https://user-images.githubusercontent.com/48054808/144848981-00c6ad21-0702-4381-9544-becb227ed9f0.gif' width='300'/>
 </div>
 - 😍 **Boy or Girl?:[StyleGAN V2 Face Editing](./docs/en_US/tutorials/styleganv2editing.md)-Changing genders!** 😍
 - **[Online Tutorials](https://aistudio.baidu.com/aistudio/projectdetail/2565277?contributionType=1)**
 <div align='center'>
@@ -118,6 +118,7 @@ GAN-Generative Adversarial Network, was praised by "the Father of Convolutional
 * [StarGANv2](docs/en_US/tutorials/starganv2.md)
 * [MPR Net](./docs/en_US/tutorials/mpr_net.md)
 * [FaceEnhancement](./docs/en_US/tutorials/face_enhancement.md)
+* [PReNet](./docs/en_US/tutorials/prenet.md)
 ## Composite Application
......
@@ -18,7 +18,7 @@ GAN--Generative Adversarial Network, by "the father of convolutional networks" **Yann LeCun
 ## Recent Activities🔥🔥🔥
 - 🔥**2021.12.08**🔥
 **💙 AI Fast Track👩‍🏫: Video Super-Resolution Algorithms and Industry Applications 💙**
 - **Course replay link🔗: https://aistudio.baidu.com/aistudio/education/group/info/25179**
@@ -140,6 +140,7 @@ GAN--Generative Adversarial Network, by "the father of convolutional networks" **Yann LeCun
 * Image and video restoration
 * Image deblurring, denoising, and deraining: [MPR Net](./docs/zh_CN/tutorials/mpr_net.md)
 * Video deblurring: [EDVR](./docs/zh_CN/tutorials/video_super_resolution.md)
+* Image deraining: [PReNet](./docs/zh_CN/tutorials/prenet.md)
 ## Industrial Applications
......
total_iters: 300000
output_dir: output_dir
find_unused_parameters: True
checkpoints_dir: checkpoints
use_dataset: True
# tensor range for function tensor2img
min_max:
  (0., 1.)
model:
  name: PReNetModel
  generator:
    name: PReNet
  pixel_criterion:
    name: SSIM
dataset:
  train:
    name: SRDataset
    gt_folder: data/RainH/RainTrainH/norain
    lq_folder: data/RainH/RainTrainH/rain
    num_workers: 4
    batch_size: 16
    scale: 1
    preprocess:
      - name: LoadImageFromFile
        key: lq
      - name: LoadImageFromFile
        key: gt
      - name: Transforms
        input_keys: [lq, gt]
        pipeline:
          - name: PairedRandomHorizontalFlip
            keys: [image, image]
          - name: PairedRandomVerticalFlip
            keys: [image, image]
          - name: PairedRandomTransposeHW
            keys: [image, image]
          - name: PairedRandomCrop
            size: [100, 100]
            keys: [image, image]
          - name: PairedToTensor
            keys: [image, image]
  test:
    name: SRDataset
    gt_folder: data/RainH/Rain100H/norain
    lq_folder: data/RainH/Rain100H/rain
    scale: 1
    preprocess:
      - name: LoadImageFromFile
        key: lq
      - name: LoadImageFromFile
        key: gt
      - name: Transforms
        input_keys: [lq, gt]
        pipeline:
          - name: PairedToTensor
            keys: [image, image]
lr_scheduler:
  name: MultiStepDecay
  learning_rate: 0.0013
  milestones: [36000,60000,96000]
  gamma: 0.2
optimizer:
  name: Adam
  # add parameters of net_name to optim
  # name should be in self.nets
  net_names:
    - generator
  beta1: 0.9
  beta2: 0.99
validate:
  interval: 5000
  save_img: false
  metrics:
    psnr: # metric name, can be arbitrary
      name: PSNR
      crop_border: 0
      test_y_channel: True
    ssim:
      name: SSIM
      crop_border: 0
      test_y_channel: True
log_config:
  interval: 100
  visiual_interval: 500
snapshot_config:
  interval: 5000
export_model:
  - {name: 'generator', inputs_num: 1}
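As a rough illustration of the `lr_scheduler` block in this config, the sketch below computes the learning rate a multi-step decay would give at a chosen iteration: the base rate is multiplied by `gamma` once for every milestone already reached. The helper name `multistep_lr` is hypothetical, and the boundary convention (the decay applies from the milestone iteration onward) is an assumption; this is not PaddleGAN's own implementation.

```python
def multistep_lr(iteration, base_lr=0.0013,
                 milestones=(36000, 60000, 96000), gamma=0.2):
    """Learning rate after multiplying by gamma once per milestone reached."""
    passed = sum(1 for m in milestones if iteration >= m)
    return base_lr * gamma ** passed


# Print the rate at the start and at each milestone from the config.
for it in (0, 36000, 60000, 96000):
    print(it, multistep_lr(it))
```

With the values from the config, the rate drops by a factor of five at each of the three milestones and then stays constant until `total_iters`.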
# PReNet
## 1 Introduction
"Progressive Image Deraining Networks: A Better and Simpler Baseline" provides a better and simpler baseline deraining network by considering network architecture, input and output, and loss functions.
<div align="center">
<img src="https://github.com/simonsLiang/PReNet_paddle/blob/main/data/net.jpg" width="800">
</div>
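The PReNet generator added in this commit runs the same stage `recurrent_iter` times (6 by default): each stage sees the rainy input together with the current estimate, updates an LSTM-style state, and adds its output back to the input. The scalar toy below mirrors only that control flow and the gate equations `c = f * c + i * g` and `h = o * tanh(c)`. The weights in `W` and the function names are made up for illustration, and plain multiplications stand in for the convolutions and residual blocks.

```python
import math


def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))


# Hypothetical scalar weights standing in for the convolution layers.
W = {"in": 0.5, "i": 1.0, "f": 1.0, "g": 1.0, "o": 1.0, "out": 0.1}


def progressive_derain(rainy, iterations=6):
    """Scalar stand-in for the staged loop; returns one estimate per stage."""
    x, h, c = rainy, 0.0, 0.0
    estimates = []
    for _ in range(iterations):
        # Stand-in for the first conv on (input, x), combined with the hidden state.
        z = W["in"] * (rainy + x) + h
        i = sigmoid(W["i"] * z)
        f = sigmoid(W["f"] * z)
        o = sigmoid(W["o"] * z)
        g = math.tanh(W["g"] * z)
        c = f * c + i * g          # cell state update
        h = o * math.tanh(c)       # hidden state
        x = W["out"] * h + rainy   # output layer plus skip connection to the input
        estimates.append(x)
    return estimates


print(progressive_derain(0.8))
```

Because the hidden and cell states are carried across iterations, later stages can refine the estimate produced by earlier ones while reusing the same weights.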
## 2 How to use
### 2.1 Prepare dataset
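The dataset (RainH.zip) used by PReNet can be downloaded from [here](https://pan.baidu.com/s/1_vxCatOV3sOA6Vkx1l23eA?pwd=vitu). Uncompress it to get two folders, RainTrainH and Rain100H. Make sure the extracted paths match `gt_folder` and `lq_folder` in `configs/prenet.yaml`, which in this commit point to `data/RainH/...`.
The dataset is structured as follows: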
```
├── data
    ├── RainTrainH
        ├── rain
            ├── 1.png
            └── 2.png
            .
            .
        └── norain
            ├── 1.png
            └── 2.png
            .
            .
    └── Rain100H
        ├── rain
            ├── 001.png
            └── 002.png
            .
            .
        └── norain
            ├── 001.png
            └── 002.png
            .
            .
```
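Before training, it can help to confirm that every rainy image has a clean counterpart. The sketch below is one way to do that, assuming (as the tree above suggests) that `rain/` and `norain/` hold files with identical names; the helper `list_pairs` is hypothetical and is not part of PaddleGAN. The demo builds a throwaway directory rather than touching real data.

```python
import os
import tempfile


def list_pairs(split_dir):
    """Return sorted (rain, norain) path pairs; raise if the folders differ."""
    rain_dir = os.path.join(split_dir, "rain")
    norain_dir = os.path.join(split_dir, "norain")
    rain = sorted(os.listdir(rain_dir))
    norain = sorted(os.listdir(norain_dir))
    if rain != norain:
        raise ValueError("rain/ and norain/ file names differ in " + split_dir)
    return [(os.path.join(rain_dir, n), os.path.join(norain_dir, n))
            for n in rain]


# Demo on a temporary directory that mimics the layout above.
with tempfile.TemporaryDirectory() as root:
    split = os.path.join(root, "Rain100H")
    for sub in ("rain", "norain"):
        os.makedirs(os.path.join(split, sub))
        for name in ("001.png", "002.png"):
            open(os.path.join(split, sub, name), "wb").close()
    pairs = list_pairs(split)
    print(len(pairs))
```

If the file names in your copy of the dataset differ between the two folders, adapt the comparison accordingly.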
### 2.2 Train/Test
Train the model:
```
python -u tools/main.py --config-file configs/prenet.yaml
```
Test the model:
```
python tools/main.py --config-file configs/prenet.yaml --evaluate-only --load ${PATH_OF_WEIGHT}
```
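The config trains with `pixel_criterion: SSIM`, and the model negates that value so that minimizing the loss maximizes structural similarity. The sketch below evaluates the SSIM formula once over two whole signals using global statistics and the constants `C1 = 0.01**2` and `C2 = 0.03**2` from the criterion in this commit. The real criterion instead applies an 11x11 Gaussian window and averages the resulting map. The function name `global_ssim` is hypothetical.

```python
def global_ssim(a, b, c1=0.01 ** 2, c2=0.03 ** 2):
    """SSIM of two equal-length sequences in [0, 1], from global statistics."""
    n = len(a)
    mu_a = sum(a) / n
    mu_b = sum(b) / n
    var_a = sum(v * v for v in a) / n - mu_a ** 2
    var_b = sum(v * v for v in b) / n - mu_b ** 2
    cov = sum(p * q for p, q in zip(a, b)) / n - mu_a * mu_b
    return ((2 * mu_a * mu_b + c1) * (2 * cov + c2)) / (
        (mu_a ** 2 + mu_b ** 2 + c1) * (var_a + var_b + c2))


clean = [0.1, 0.5, 0.9, 0.3]
rainy = [0.4, 0.6, 0.95, 0.7]
loss_same = -global_ssim(clean, clean)   # training loss is the negated SSIM
loss_rainy = -global_ssim(clean, rainy)
print(loss_same, loss_rainy)
```

Identical inputs give the lowest possible loss, so the optimizer is pushed toward outputs that match the clean image in mean, contrast, and structure.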
## 3 Results
Input:
<div align="center">
<img src="https://github.com/simonsLiang/PReNet_paddle/blob/main/data/rain-001.png" width="300">
</div>
Output:
<div align="center">
<img src="https://github.com/simonsLiang/PReNet_paddle/blob/main/data/derain-rain-001.png" width="300">
</div>
## 4 Model Download
| model | dataset |
|---|---|
| [PReNet](https://paddlegan.bj.bcebos.com/models/PReNet.pdparams) | [RainH.zip](https://pan.baidu.com/s/1_vxCatOV3sOA6Vkx1l23eA?pwd=vitu) |
## References
- 1. [Progressive Image Deraining Networks: A Better and Simpler Baseline](https://arxiv.org/pdf/1901.09221v3.pdf)
```
@inproceedings{ren2019progressive,
title={Progressive Image Deraining Networks: A Better and Simpler Baseline},
author={Ren, Dongwei and Zuo, Wangmeng and Hu, Qinghua and Zhu, Pengfei and Meng, Deyu},
booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
year={2019},
}
```
# PReNet
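## 1 Introduction
"Progressive Image Deraining Networks: A Better and Simpler Baseline" proposes a multi-stage progressive residual network. Each stage is a ResNet, and the input to each res block is the output of the previous res block together with the original rainy image. Training with an SSIM loss further improves performance. The overall network is simple and efficient, performs well on a variety of datasets, and provides a good baseline for image deraining.
<div align="center">
<img src="https://github.com/simonsLiang/PReNet_paddle/blob/main/data/net.jpg" width="800">
</div>
## 2 How to use
### 2.1 Prepare dataset
The dataset (RainH.zip) can be downloaded from [here](https://pan.baidu.com/s/1_vxCatOV3sOA6Vkx1l23eA?pwd=vitu). Extract it into the ./data directory.
The dataset is structured as follows: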
```
├── data
    ├── RainTrainH
        ├── rain
            ├── 1.png
            └── 2.png
            .
            .
        └── norain
            ├── 1.png
            └── 2.png
            .
            .
    └── Rain100H
        ├── rain
            ├── 001.png
            └── 002.png
            .
            .
        └── norain
            ├── 001.png
            └── 002.png
            .
            .
```
### 2.2 Train/Test
Train the model:
```
python -u tools/main.py --config-file configs/prenet.yaml
```
Test the model:
```
python tools/main.py --config-file configs/prenet.yaml --evaluate-only --load ${PATH_OF_WEIGHT}
```
## 3 Results
Input:
<div align="center">
<img src="https://github.com/simonsLiang/PReNet_paddle/blob/main/data/rain-001.png" width="300">
</div>
Output:
<div align="center">
<img src="https://github.com/simonsLiang/PReNet_paddle/blob/main/data/derain-rain-001.png" width="300">
</div>
## 4 Model Download
| model | dataset |
|---|---|
| [PReNet](https://paddlegan.bj.bcebos.com/models/PReNet.pdparams) | [RainH.zip](https://pan.baidu.com/s/1_vxCatOV3sOA6Vkx1l23eA?pwd=vitu) |
## References
- 1. [Progressive Image Deraining Networks: A Better and Simpler Baseline](https://arxiv.org/pdf/1901.09221v3.pdf)
```
@inproceedings{ren2019progressive,
title={Progressive Image Deraining Networks: A Better and Simpler Baseline},
author={Ren, Dongwei and Zuo, Wangmeng and Hu, Qinghua and Zhu, Pengfei and Meng, Deyu},
booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
year={2019},
}
```
@@ -3,6 +3,6 @@ from .transforms import (PairedRandomCrop, PairedRandomHorizontalFlip,
                          PairedRandomVerticalFlip, PairedRandomTransposeHW,
                          SRPairedRandomCrop, SplitPairedImage, SRNoise,
                          NormalizeSequence, MirrorVideoSequence,
-                         TransposeSequence)
+                         TransposeSequence, PairedToTensor)
 from .builder import build_preprocess
@@ -22,6 +22,7 @@ import numpy as np
 from PIL import Image
+import paddle
 import paddle.vision.transforms as T
 import paddle.vision.transforms.functional as F
@@ -42,10 +43,12 @@ TRANSFORMS.register(T.RandomVerticalFlip)
 TRANSFORMS.register(T.Normalize)
 TRANSFORMS.register(T.Transpose)
 TRANSFORMS.register(T.Grayscale)
+TRANSFORMS.register(T.ToTensor)
 @PREPROCESS.register()
 class Transforms():
     def __init__(self, pipeline, input_keys, output_keys=None):
         self.input_keys = input_keys
         self.output_keys = output_keys
@@ -81,6 +84,7 @@ class Transforms():
 @PREPROCESS.register()
 class SplitPairedImage:
     def __init__(self, key, paired_keys=['A', 'B']):
         self.key = key
         self.paired_keys = paired_keys
@@ -103,6 +107,7 @@ class SplitPairedImage:
 @TRANSFORMS.register()
 class PairedRandomCrop(T.RandomCrop):
     def __init__(self, size, keys=None):
         super().__init__(size, keys=keys)
@@ -122,8 +127,19 @@ class PairedRandomCrop(T.RandomCrop):
         return F.crop(img, i, j, h, w)
+@TRANSFORMS.register()
+class PairedToTensor(T.ToTensor):
+    def __init__(self, data_format='CHW', keys=None):
+        super().__init__(data_format, keys=keys)
+    def _apply_image(self, img):
+        return F.to_tensor(img)
 @TRANSFORMS.register()
 class PairedRandomHorizontalFlip(T.RandomHorizontalFlip):
     def __init__(self, prob=0.5, keys=None):
         super().__init__(prob, keys=keys)
@@ -143,6 +159,7 @@ class PairedRandomHorizontalFlip(T.RandomHorizontalFlip):
 @TRANSFORMS.register()
 class PairedRandomVerticalFlip(T.RandomHorizontalFlip):
     def __init__(self, prob=0.5, keys=None):
         super().__init__(prob, keys=keys)
@@ -176,6 +193,7 @@ class PairedRandomTransposeHW(T.BaseTransform):
         prob (float): The probability to transpose the images.
         keys (list[str]): The images to be transposed.
     """
     def __init__(self, prob=0.5, keys=None):
         self.keys = keys
         self.prob = prob
@@ -220,6 +238,7 @@ class TransposeSequence(T.Transpose):
             fake_img_seq = transform(fake_img_seq)
     """
     def _apply_image(self, img):
         if isinstance(img, list):
             imgs = []
@@ -277,6 +296,7 @@ class NormalizeSequence(T.Normalize):
             fake_img_seq = normalize_seq(fake_img_seq)
     """
     def _apply_image(self, img):
         if isinstance(img, list):
             imgs = [
@@ -302,6 +322,7 @@ class SRPairedRandomCrop(T.BaseTransform):
         scale (int): model upscale factor.
         gt_patch_size (int): cropped gt patch size.
     """
     def __init__(self, scale, gt_patch_size, scale_list=False, keys=None):
         self.gt_patch_size = gt_patch_size
         self.scale = scale
@@ -368,6 +389,7 @@ class SRNoise(T.BaseTransform):
         noise_path (str): directory of noise image.
         size (int): cropped noise patch size.
     """
     def __init__(self, noise_path, size, keys=None):
         self.noise_path = noise_path
         self.noise_imgs = sorted(glob.glob(noise_path + '*.png'))
@@ -396,6 +418,7 @@ class RandomResizedCropProb(T.RandomResizedCrop):
         prob (float): probability of using random-resized cropping.
         size (int): cropped size.
     """
     def __init__(self, prob, size, scale, ratio, interpolation, keys=None):
         super().__init__(size, scale, ratio, interpolation)
         self.prob = prob
@@ -409,6 +432,7 @@ class RandomResizedCropProb(T.RandomResizedCrop):
 @TRANSFORMS.register()
 class Add(T.BaseTransform):
     def __init__(self, value, keys=None):
         """Initialize Add Transform
@@ -430,6 +454,7 @@ class Add(T.BaseTransform):
 @TRANSFORMS.register()
 class ResizeToScale(T.BaseTransform):
     def __init__(self,
                  size: int,
                  scale: int,
@@ -480,6 +505,7 @@ class ResizeToScale(T.BaseTransform):
 @TRANSFORMS.register()
 class PairedColorJitter(T.BaseTransform):
     def __init__(self,
                  brightness=0,
                  contrast=0,
@@ -545,6 +571,7 @@ class MirrorVideoSequence:
     Args:
         keys (list[str]): The frame lists to be extended.
     """
     def __init__(self, keys=None):
         self.keys = keys
......
@@ -35,3 +35,4 @@ from .mpr_model import MPRModel
 from .photopen_model import PhotoPenModel
 from .msvsr_model import MultiStageVSRModel
 from .singan_model import SinGANModel
+from .prenet_model import PReNetModel
@@ -7,3 +7,5 @@ from .photopen_perceptual_loss import PhotoPenPerceptualLoss
 from .gradient_penalty import GradientPenalty
 from .builder import build_criterion
+from .ssim import SSIM
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# code was heavily based on https://github.com/csdwren/PReNet
# Users should be careful about adopting these functions in any commercial matters.
import numpy as np
from math import exp

import paddle
import paddle.nn.functional as F

from .builder import CRITERIONS


def gaussian(window_size, sigma):
    # 1-D Gaussian kernel, normalized to sum to 1.
    gauss = paddle.to_tensor([
        exp(-(x - window_size // 2)**2 / float(2 * sigma**2))
        for x in range(window_size)
    ])
    return gauss / gauss.sum()


def create_window(window_size, channel):
    # Outer product of the 1-D kernel gives the 2-D window, one copy per channel.
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).unsqueeze(0).unsqueeze(0)
    window = paddle.to_tensor(paddle.expand(
        _2D_window, (channel, 1, window_size, window_size)),
                              stop_gradient=False)
    return window


def _ssim(img1, img2, window, window_size, channel, size_average=True):
    # Local means via depthwise (grouped) convolution with the Gaussian window.
    mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
    mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    # Local variances and covariance.
    sigma1_sq = F.conv2d(
        img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(
        img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
    sigma12 = F.conv2d(
        img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2

    C1 = 0.01**2
    C2 = 0.03**2

    ssim_map = ((2 * mu1_mu2 + C1) *
                (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
                                       (sigma1_sq + sigma2_sq + C2))

    if size_average:
        return ssim_map.mean()
    else:
        return ssim_map.mean(1).mean(1).mean(1)


@CRITERIONS.register()
class SSIM(paddle.nn.Layer):
    def __init__(self, window_size=11, size_average=True):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2):
        (_, channel, _, _) = img1.shape

        # Reuse the cached window when channel count and dtype still match.
        if channel == self.channel and self.window.dtype == img1.dtype:
            window = self.window
        else:
            window = create_window(self.window_size, channel)
            tt = img1.dtype
            window = paddle.to_tensor(window, dtype=tt)
            self.window = window
            self.channel = channel

        return _ssim(img1, img2, window, self.window_size, channel,
                     self.size_average)


def ssim(img1, img2, window_size=11, size_average=True):
    # Paddle port of the PyTorch helper: use `.shape` and `paddle.cast` in
    # place of the PyTorch-only `.size()`, `.is_cuda`, `.cuda()` and `.type_as()`.
    (_, channel, _, _) = img1.shape
    window = create_window(window_size, channel)
    window = paddle.cast(window, img1.dtype)
    return _ssim(img1, img2, window, window_size, channel, size_average)
@@ -39,3 +39,4 @@ from .generater_photopen import SPADEGenerator
 from .basicvsr_plus_plus import BasicVSRPlusPlus
 from .msvsr import MSVSR
 from .generator_singan import SinGANGenerator
+from .prenet import PReNet
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# code was heavily based on https://github.com/csdwren/PReNet
# Users should be careful about adopting these functions in any commercial matters.
import math

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from .builder import GENERATORS


def convWithBias(in_channels, out_channels, kernel_size, stride, padding):
    """ Obtain a 2d convolution layer with bias and initialized by KaimingUniform
    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        kernel_size (int): Convolution kernel size
        stride (int): Convolution stride
        padding (int|tuple): Convolution padding.
    """
    if isinstance(kernel_size, int):
        fan_in = kernel_size * kernel_size * in_channels
    else:
        fan_in = kernel_size[0] * kernel_size[1] * in_channels
    bound = 1 / math.sqrt(fan_in)
    bias_attr = paddle.framework.ParamAttr(
        initializer=nn.initializer.Uniform(-bound, bound))
    # fan_in is scaled by 6, presumably so that the weight bound also
    # works out to 1 / sqrt(fan_in).
    weight_attr = paddle.framework.ParamAttr(
        initializer=nn.initializer.KaimingUniform(fan_in=6 * fan_in))
    conv = nn.Conv2D(in_channels,
                     out_channels,
                     kernel_size,
                     stride,
                     padding,
                     weight_attr=weight_attr,
                     bias_attr=bias_attr)
    return conv


@GENERATORS.register()
class PReNet(nn.Layer):
    """
    Args:
        recurrent_iter (int): Number of iterations.
            Default: 6.
        use_GPU (bool): whether to use GPU or not.
            Default: True.
    """
    def __init__(self, recurrent_iter=6, use_GPU=True):
        super(PReNet, self).__init__()
        self.iteration = recurrent_iter
        self.use_GPU = use_GPU

        self.conv0 = nn.Sequential(convWithBias(6, 32, 3, 1, 1), nn.ReLU())
        self.res_conv1 = nn.Sequential(convWithBias(32, 32, 3, 1, 1), nn.ReLU(),
                                       convWithBias(32, 32, 3, 1, 1), nn.ReLU())
        self.res_conv2 = nn.Sequential(convWithBias(32, 32, 3, 1, 1), nn.ReLU(),
                                       convWithBias(32, 32, 3, 1, 1), nn.ReLU())
        self.res_conv3 = nn.Sequential(convWithBias(32, 32, 3, 1, 1), nn.ReLU(),
                                       convWithBias(32, 32, 3, 1, 1), nn.ReLU())
        self.res_conv4 = nn.Sequential(convWithBias(32, 32, 3, 1, 1), nn.ReLU(),
                                       convWithBias(32, 32, 3, 1, 1), nn.ReLU())
        self.res_conv5 = nn.Sequential(convWithBias(32, 32, 3, 1, 1), nn.ReLU(),
                                       convWithBias(32, 32, 3, 1, 1), nn.ReLU())
        # LSTM gates: input, forget, candidate, output.
        self.conv_i = nn.Sequential(convWithBias(32 + 32, 32, 3, 1, 1),
                                    nn.Sigmoid())
        self.conv_f = nn.Sequential(convWithBias(32 + 32, 32, 3, 1, 1),
                                    nn.Sigmoid())
        self.conv_g = nn.Sequential(convWithBias(32 + 32, 32, 3, 1, 1),
                                    nn.Tanh())
        self.conv_o = nn.Sequential(convWithBias(32 + 32, 32, 3, 1, 1),
                                    nn.Sigmoid())
        self.conv = nn.Sequential(convWithBias(32, 3, 3, 1, 1), )

    def forward(self, input):
        batch_size, row, col = input.shape[0], input.shape[2], input.shape[3]

        x = input
        h = paddle.to_tensor(paddle.zeros(shape=(batch_size, 32, row, col),
                                          dtype='float32'),
                             stop_gradient=False)
        c = paddle.to_tensor(paddle.zeros(shape=(batch_size, 32, row, col),
                                          dtype='float32'),
                             stop_gradient=False)

        x_list = []
        for _ in range(self.iteration):
            # Each stage sees the rainy input and the current estimate.
            x = paddle.concat((input, x), 1)
            x = self.conv0(x)

            # Convolutional LSTM update.
            x = paddle.concat((x, h), 1)
            i = self.conv_i(x)
            f = self.conv_f(x)
            g = self.conv_g(x)
            o = self.conv_o(x)
            c = f * c + i * g
            h = o * paddle.tanh(c)

            # Five residual blocks.
            x = h
            resx = x
            x = F.relu(self.res_conv1(x) + resx)
            resx = x
            x = F.relu(self.res_conv2(x) + resx)
            resx = x
            x = F.relu(self.res_conv3(x) + resx)
            resx = x
            x = F.relu(self.res_conv4(x) + resx)
            resx = x
            x = F.relu(self.res_conv5(x) + resx)
            x = self.conv(x)

            # Skip connection back to the rainy input.
            x = x + input
            x_list.append(x)
        return x
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn

from .builder import MODELS
from .sr_model import BaseSRModel
from .generators.iconvsr import EDVRFeatureExtractor
from .generators.basicvsr import ResidualBlockNoBN, PixelShufflePack, SPyNet
from ..modules.init import reset_parameters
from ..utils.visual import tensor2img


@MODELS.register()
class PReNetModel(BaseSRModel):
    """PReNet Model.

    Paper: Progressive Image Deraining Networks: A Better and Simpler Baseline, IEEE,2019
    """
    def __init__(self, generator, pixel_criterion=None):
        """Initialize the PReNetModel class.

        Args:
            generator (dict): config of generator.
            pixel_criterion (dict): config of pixel criterion.
        """
        super(PReNetModel, self).__init__(generator, pixel_criterion)
        self.current_iter = 1
        self.flag = True

    def setup_input(self, input):
        self.lq = input['lq']
        self.visual_items['lq'] = self.lq[0, :, :, :]
        if 'gt' in input:
            self.gt = input['gt']
            self.visual_items['gt'] = self.gt[0, :, :, :]
        self.image_paths = input['lq_path']

    def train_iter(self, optims=None):
        optims['optim'].clear_grad()
        self.output = self.nets['generator'](self.lq)
        self.visual_items['output'] = self.output[0, :, :, :]
        # pixel loss: negated SSIM, so minimizing it maximizes similarity
        loss_pixel = -self.pixel_criterion(self.output, self.gt)
        loss_pixel.backward()
        optims['optim'].step()

        self.losses['loss_pixel'] = loss_pixel

        self.current_iter += 1

    def test_iter(self, metrics=None):
        self.gt = self.gt.cpu()
        self.nets['generator'].eval()
        with paddle.no_grad():
            output = self.nets['generator'](self.lq)
            self.visual_items['output'] = output[0, :, :, :].cpu()
        self.nets['generator'].train()

        out_img = []
        gt_img = []

        out_tensor = output[0]
        gt_tensor = self.gt[0]
        out_img = tensor2img(out_tensor, (0., 1.))
        gt_img = tensor2img(gt_tensor, (0., 1.))

        if metrics is not None:
            for metric in metrics.values():
                metric.update(out_img, gt_img, is_seq=True)
@@ -15,7 +15,7 @@ from ppgan.metrics import build_metric
 MODEL_CLASSES = ["pix2pix", "cyclegan", "wav2lip", "esrgan", \
-                 "edvr", "fom", "stylegan2", "basicvsr", "msvsr", "singan"]
+                 "edvr", "fom", "stylegan2", "basicvsr", "msvsr", "singan","prenet"]
 def parse_args():
@@ -56,63 +56,55 @@ def parse_args():
     parser.add_argument('--seed',
                         type=int,
                         default=None,
-                        help='fix random numbers by setting seed\".'
-    )
+                        help='fix random numbers by setting seed\".')
     # for tensorRT
-    parser.add_argument(
-        "--run_mode",
-        default="fluid",
-        type=str,
-        choices=["fluid", "trt_fp32", "trt_fp16"],
-        help="mode of running(fluid/trt_fp32/trt_fp16)")
-    parser.add_argument(
-        "--trt_min_shape",
-        default=1,
-        type=int,
-        help="trt_min_shape for tensorRT")
-    parser.add_argument(
-        "--trt_max_shape",
-        default=1280,
-        type=int,
-        help="trt_max_shape for tensorRT")
-    parser.add_argument(
-        "--trt_opt_shape",
-        default=640,
-        type=int,
-        help="trt_opt_shape for tensorRT")
-    parser.add_argument(
-        "--min_subgraph_size",
-        default=3,
-        type=int,
-        help="trt_opt_shape for tensorRT")
-    parser.add_argument(
-        "--batch_size",
-        default=1,
-        type=int,
-        help="batch_size for tensorRT")
-    parser.add_argument(
-        "--use_dynamic_shape",
-        dest="use_dynamic_shape",
-        action="store_true",
-        help="use_dynamic_shape for tensorRT")
-    parser.add_argument(
-        "--trt_calib_mode",
-        dest="trt_calib_mode",
-        action="store_true",
-        help="trt_calib_mode for tensorRT")
+    parser.add_argument("--run_mode",
+                        default="fluid",
+                        type=str,
+                        choices=["fluid", "trt_fp32", "trt_fp16"],
+                        help="mode of running(fluid/trt_fp32/trt_fp16)")
+    parser.add_argument("--trt_min_shape",
+                        default=1,
+                        type=int,
+                        help="trt_min_shape for tensorRT")
+    parser.add_argument("--trt_max_shape",
+                        default=1280,
+                        type=int,
+                        help="trt_max_shape for tensorRT")
+    parser.add_argument("--trt_opt_shape",
+                        default=640,
+                        type=int,
+                        help="trt_opt_shape for tensorRT")
+    parser.add_argument("--min_subgraph_size",
+                        default=3,
+                        type=int,
+                        help="trt_opt_shape for tensorRT")
+    parser.add_argument("--batch_size",
+                        default=1,
+                        type=int,
+                        help="batch_size for tensorRT")
+    parser.add_argument("--use_dynamic_shape",
+                        dest="use_dynamic_shape",
+                        action="store_true",
+                        help="use_dynamic_shape for tensorRT")
+    parser.add_argument("--trt_calib_mode",
+                        dest="trt_calib_mode",
+                        action="store_true",
+                        help="trt_calib_mode for tensorRT")
     args = parser.parse_args()
     return args
-def create_predictor(model_path, device="gpu",
-                     run_mode='fluid',
-                     batch_size=1,
-                     min_subgraph_size=3,
-                     use_dynamic_shape=False,
-                     trt_min_shape=1,
-                     trt_max_shape=1280,
-                     trt_opt_shape=640,
-                     trt_calib_mode=False):
+def create_predictor(model_path,
+                     device="gpu",
+                     run_mode='fluid',
+                     batch_size=1,
+                     min_subgraph_size=3,
+                     use_dynamic_shape=False,
+                     trt_min_shape=1,
+                     trt_max_shape=1280,
+                     trt_opt_shape=640,
+                     trt_calib_mode=False):
     config = paddle.inference.Config(model_path + ".pdmodel",
                                      model_path + ".pdiparams")
     if device == "gpu":
@@ -123,20 +115,19 @@ def create_predictor(model_path, device="gpu",
         config.enable_xpu(100)
     else:
         config.disable_gpu()
     precision_map = {
         'trt_int8': paddle.inference.Config.Precision.Int8,
         'trt_fp32': paddle.inference.Config.Precision.Float32,
         'trt_fp16': paddle.inference.Config.Precision.Half
     }
     if run_mode in precision_map.keys():
-        config.enable_tensorrt_engine(
-            workspace_size=1 << 25,
-            max_batch_size=batch_size,
-            min_subgraph_size=min_subgraph_size,
-            precision_mode=precision_map[run_mode],
-            use_static=False,
-            use_calib_mode=trt_calib_mode)
+        config.enable_tensorrt_engine(workspace_size=1 << 25,
+                                      max_batch_size=batch_size,
+                                      min_subgraph_size=min_subgraph_size,
+                                      precision_mode=precision_map[run_mode],
+                                      use_static=False,
+                                      use_calib_mode=trt_calib_mode)
         if use_dynamic_shape:
             min_input_shape = {
@@ -155,6 +146,7 @@ def create_predictor(model_path, device="gpu",
     predictor = paddle.inference.create_predictor(config)
     return predictor
 def setup_metrics(cfg):
     metrics = OrderedDict()
     if isinstance(list(cfg.values())[0], dict):
@@ -166,22 +158,18 @@ def setup_metrics(cfg):
     return metrics
 def main():
     args = parse_args()
     if args.seed:
         paddle.seed(args.seed)
         random.seed(args.seed)
         np.random.seed(args.seed)
     cfg = get_config(args.config_file, args.opt)
-    predictor = create_predictor(args.model_path,
-                                 args.device,
-                                 args.run_mode,
-                                 args.batch_size,
-                                 args.min_subgraph_size,
-                                 args.use_dynamic_shape,
-                                 args.trt_min_shape,
-                                 args.trt_max_shape,
-                                 args.trt_opt_shape,
+    predictor = create_predictor(args.model_path, args.device, args.run_mode,
+                                 args.batch_size, args.min_subgraph_size,
+                                 args.use_dynamic_shape, args.trt_min_shape,
+                                 args.trt_max_shape, args.trt_opt_shape,
                                  args.trt_calib_mode)
     input_handles = [
         predictor.get_input_handle(name)
@@ -218,7 +206,9 @@ def main():
             prediction = output_handle.copy_to_cpu()
             prediction = paddle.to_tensor(prediction)
             image_numpy = tensor2img(prediction[0], min_max)
-            save_image(image_numpy, os.path.join(args.output_path, "pix2pix/{}.png".format(i)))
+            save_image(
+                image_numpy,
+                os.path.join(args.output_path, "pix2pix/{}.png".format(i)))
             metric_file = os.path.join(args.output_path, "pix2pix/metric.txt")
             real_B = paddle.to_tensor(data['A'])
             for metric in metrics.values():
@@ -231,7 +221,9 @@ def main():
             prediction = output_handle.copy_to_cpu()
             prediction = paddle.to_tensor(prediction)
             image_numpy = tensor2img(prediction[0], min_max)
-            save_image(image_numpy, os.path.join(args.output_path, "cyclegan/{}.png".format(i)))
+            save_image(
+                image_numpy,
+                os.path.join(args.output_path, "cyclegan/{}.png".format(i)))
             metric_file = os.path.join(args.output_path, "cyclegan/metric.txt")
             real_B = paddle.to_tensor(data['B'])
             for metric in metrics.values():
@@ -275,7 +267,9 @@ def main():
             prediction = output_handle.copy_to_cpu()
             prediction = paddle.to_tensor(prediction)
image_numpy = tensor2img(prediction[0], min_max) image_numpy = tensor2img(prediction[0], min_max)
save_image(image_numpy, os.path.join(args.output_path, "stylegan2/{}.png".format(i))) save_image(
image_numpy,
os.path.join(args.output_path, "stylegan2/{}.png".format(i)))
metric_file = os.path.join(args.output_path, "stylegan2/metric.txt") metric_file = os.path.join(args.output_path, "stylegan2/metric.txt")
real_img = paddle.to_tensor(data['A']) real_img = paddle.to_tensor(data['A'])
for metric in metrics.values(): for metric in metrics.values():
...@@ -285,7 +279,8 @@ def main(): ...@@ -285,7 +279,8 @@ def main():
input_handles[0].copy_from_cpu(lq) input_handles[0].copy_from_cpu(lq)
predictor.run() predictor.run()
if len(predictor.get_output_names()) > 1: if len(predictor.get_output_names()) > 1:
output_handle = predictor.get_output_handle(predictor.get_output_names()[-1]) output_handle = predictor.get_output_handle(
predictor.get_output_names()[-1])
prediction = output_handle.copy_to_cpu() prediction = output_handle.copy_to_cpu()
prediction = paddle.to_tensor(prediction) prediction = paddle.to_tensor(prediction)
_, t, _, _, _ = prediction.shape _, t, _, _, _ = prediction.shape
...@@ -295,13 +290,16 @@ def main(): ...@@ -295,13 +290,16 @@ def main():
for ti in range(t): for ti in range(t):
out_tensor = prediction[0, ti] out_tensor = prediction[0, ti]
gt_tensor = data['gt'][0, ti] gt_tensor = data['gt'][0, ti]
out_img.append(tensor2img(out_tensor, (0.,1.))) out_img.append(tensor2img(out_tensor, (0., 1.)))
gt_img.append(tensor2img(gt_tensor, (0.,1.))) gt_img.append(tensor2img(gt_tensor, (0., 1.)))
image_numpy = tensor2img(prediction[0], min_max) image_numpy = tensor2img(prediction[0], min_max)
save_image(image_numpy, os.path.join(args.output_path, model_type, "{}.png".format(i))) save_image(
image_numpy,
os.path.join(args.output_path, model_type, "{}.png".format(i)))
metric_file = os.path.join(args.output_path, model_type, "metric.txt") metric_file = os.path.join(args.output_path, model_type,
"metric.txt")
for metric in metrics.values(): for metric in metrics.values():
metric.update(out_img, gt_img, is_seq=True) metric.update(out_img, gt_img, is_seq=True)
elif model_type == "singan": elif model_type == "singan":
...@@ -309,18 +307,38 @@ def main(): ...@@ -309,18 +307,38 @@ def main():
prediction = output_handle.copy_to_cpu() prediction = output_handle.copy_to_cpu()
prediction = paddle.to_tensor(prediction) prediction = paddle.to_tensor(prediction)
image_numpy = tensor2img(prediction, min_max) image_numpy = tensor2img(prediction, min_max)
save_image(image_numpy, os.path.join(args.output_path, "singan/{}.png".format(i))) save_image(
image_numpy,
os.path.join(args.output_path, "singan/{}.png".format(i)))
metric_file = os.path.join(args.output_path, "singan/metric.txt") metric_file = os.path.join(args.output_path, "singan/metric.txt")
for metric in metrics.values(): for metric in metrics.values():
metric.update(prediction, data['A']) metric.update(prediction, data['A'])
elif model_type == "prenet":
lq = data['lq'].numpy()
gt = data['gt'].numpy()
input_handles[0].copy_from_cpu(lq)
predictor.run()
prediction = output_handle.copy_to_cpu()
prediction = paddle.to_tensor(prediction)
gt = paddle.to_tensor(gt)
image_numpy = tensor2img(prediction, min_max)
gt_img = tensor2img(gt, min_max)
save_image(
image_numpy,
os.path.join(args.output_path, "prenet/{}.png".format(i)))
metric_file = os.path.join(args.output_path, "prenet/metric.txt")
for metric in metrics.values():
metric.update(image_numpy, gt_img)
if metrics: if metrics:
log_file = open(metric_file, 'a') log_file = open(metric_file, 'a')
for metric_name, metric in metrics.items(): for metric_name, metric in metrics.items():
loss_string = "Metric {}: {:.4f}".format( loss_string = "Metric {}: {:.4f}".format(metric_name,
metric_name, metric.accumulate()) metric.accumulate())
print(loss_string, file=log_file) print(loss_string, file=log_file)
log_file.close() log_file.close()
if __name__ == '__main__': if __name__ == '__main__':
main() main()
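The closing `if metrics:` block of `main()` appends one line per metric to `metric_file` in the form `Metric <name>: <value>` with four decimals. Below is a minimal, framework-free sketch of that logging pattern. `MeanAbsDiff` and `write_metrics` are hypothetical names used only for illustration and are not part of PaddleGAN. The sketch assumes that a metric only needs `update()` and `accumulate()` methods, and it swaps the explicit `open()`/`close()` pair for a `with` block so the file is closed even if formatting raises.

```python
import os
from collections import OrderedDict


class MeanAbsDiff:
    """Hypothetical stand-in metric exposing update()/accumulate()."""

    def __init__(self):
        self.values = []

    def update(self, pred, gt):
        # Toy score: mean absolute difference between two equal-length sequences.
        diffs = [abs(p - g) for p, g in zip(pred, gt)]
        self.values.append(sum(diffs) / len(diffs))

    def accumulate(self):
        # Average over all update() calls; 0.0 if nothing was recorded.
        return sum(self.values) / len(self.values) if self.values else 0.0


def write_metrics(metrics, metric_file):
    """Append 'Metric <name>: <value>' lines, mirroring the format used in main()."""
    out_dir = os.path.dirname(metric_file)
    if out_dir:
        # Assumption: create the output directory if it does not exist yet.
        os.makedirs(out_dir, exist_ok=True)
    with open(metric_file, 'a') as log_file:
        for metric_name, metric in metrics.items():
            loss_string = "Metric {}: {:.4f}".format(metric_name,
                                                     metric.accumulate())
            print(loss_string, file=log_file)


# Usage: one metric, one update, then append to <output_path>/prenet/metric.txt
metrics = OrderedDict(mad=MeanAbsDiff())
metrics["mad"].update([1.0, 2.0], [1.5, 2.5])
```

Because the file is opened in append mode, repeated runs against the same output path accumulate lines rather than overwrite them, which matches the `'a'` mode in the script.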