Unverified commit 7e9fd3fc, authored by simonsLiang, committed by GitHub

Add PReNet to PaddleGan (#617)

* Create prenet.md

* Update prenet.md

* Update prenet.md

* Update prenet.md

* Update prenet.md

* Update prenet.md

* Update prenet.md

* Update prenet.md

* Update prenet.md

* Update README.md

* upload

* Delete test_tipc/configs/PReNet directory

* Update prenet_model.py

* Update ssim.py

* Update ssim.py

* Update prenet.yaml
Parent 5df8fc28
......@@ -34,7 +34,7 @@ GAN-Generative Adversarial Network, was praised by "the Father of Convolutional
<div align='center'>
<img src='https://user-images.githubusercontent.com/48054808/144848981-00c6ad21-0702-4381-9544-becb227ed9f0.gif' width='300'/>
</div>
- 😍 **Boy or Girl?:[StyleGAN V2 Face Editing](./docs/en_US/tutorials/styleganv2editing.md)-Changing genders!** 😍
- **[Online Tutorials](https://aistudio.baidu.com/aistudio/projectdetail/2565277?contributionType=1)**
<div align='center'>
......@@ -118,6 +118,7 @@ GAN-Generative Adversarial Network, was praised by "the Father of Convolutional
* [StarGANv2](docs/en_US/tutorials/starganv2.md)
* [MPR Net](./docs/en_US/tutorials/mpr_net.md)
* [FaceEnhancement](./docs/en_US/tutorials/face_enhancement.md)
* [PReNet](./docs/en_US/tutorials/prenet.md)
## Composite Application
......
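......@@ -18,7 +18,7 @@ GAN (Generative Adversarial Network), praised by the "father of convolutional networks" **Yann LeCun (杨立昆)
## Recent Events🔥🔥🔥
- 🔥**2021.12.08**🔥
**💙 AI Express👩‍🏫: Video Super-Resolution Algorithms and Industry Applications 💙**
- **Course replay link🔗: https://aistudio.baidu.com/aistudio/education/group/info/25179**
......@@ -140,6 +140,7 @@ GAN (Generative Adversarial Network), praised by the "father of convolutional networks" **Yann LeCun (杨立昆)
* Image and video restoration
    * Image deblurring, denoising, and deraining: [MPR Net](./docs/zh_CN/tutorials/mpr_net.md)
    * Video deblurring: [EDVR](./docs/zh_CN/tutorials/video_super_resolution.md)
    * Image deraining: [PReNet](./docs/zh_CN/tutorials/prenet.md)
## Industrial Applications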
......
total_iters: 300000
output_dir: output_dir
find_unused_parameters: True
checkpoints_dir: checkpoints
use_dataset: True
# tensor range for function tensor2img
min_max:
  (0., 1.)

model:
  name: PReNetModel
  generator:
    name: PReNet
  pixel_criterion:
    name: SSIM

dataset:
  train:
    name: SRDataset
    gt_folder: data/RainH/RainTrainH/norain
    lq_folder: data/RainH/RainTrainH/rain
    num_workers: 4
    batch_size: 16
    scale: 1
    preprocess:
      - name: LoadImageFromFile
        key: lq
      - name: LoadImageFromFile
        key: gt
      - name: Transforms
        input_keys: [lq, gt]
        pipeline:
          - name: PairedRandomHorizontalFlip
            keys: [image, image]
          - name: PairedRandomVerticalFlip
            keys: [image, image]
          - name: PairedRandomTransposeHW
            keys: [image, image]
          - name: PairedRandomCrop
            size: [100, 100]
            keys: [image, image]
          - name: PairedToTensor
            keys: [image, image]
  test:
    name: SRDataset
    gt_folder: data/RainH/Rain100H/norain
    lq_folder: data/RainH/Rain100H/rain
    scale: 1
    preprocess:
      - name: LoadImageFromFile
        key: lq
      - name: LoadImageFromFile
        key: gt
      - name: Transforms
        input_keys: [lq, gt]
        pipeline:
          - name: PairedToTensor
            keys: [image, image]

lr_scheduler:
  name: MultiStepDecay
  learning_rate: 0.0013
  milestones: [36000,60000,96000]
  gamma: 0.2

optimizer:
  name: Adam
  # add parameters of net_name to optim
  # name should in self.nets
  net_names:
    - generator
  beta1: 0.9
  beta2: 0.99

validate:
  interval: 5000
  save_img: false
  metrics:
    psnr: # metric name, can be arbitrary
      name: PSNR
      crop_border: 0
      test_y_channel: True
    ssim:
      name: SSIM
      crop_border: 0
      test_y_channel: True

log_config:
  interval: 100
  visiual_interval: 500

snapshot_config:
  interval: 5000

export_model:
  - {name: 'generator', inputs_num: 1}
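The `lr_scheduler` block in the config above describes a multi-step decay: the learning rate starts at 0.0013 and is multiplied by `gamma` each time a milestone is passed. The following is a minimal pure-Python sketch of that arithmetic, not the Paddle scheduler itself. It assumes the scheduler is stepped once per training iteration, and the function name `multistep_lr` is hypothetical.

```python
def multistep_lr(iteration,
                 base_lr=0.0013,
                 milestones=(36000, 60000, 96000),
                 gamma=0.2):
    """Learning rate at `iteration` under a multi-step decay schedule.

    Assumption: one scheduler step per training iteration, and the decay
    takes effect once `iteration` reaches a milestone.
    """
    # Count how many milestones have already been reached.
    passed = sum(1 for m in milestones if iteration >= m)
    return base_lr * gamma ** passed


# Print the rate in each of the four plateaus of the schedule.
for it in (0, 36000, 60000, 96000):
    print(it, multistep_lr(it))
```

With the values from the config, the rate drops by a factor of five at iterations 36000, 60000, and 96000, and then stays constant until `total_iters`.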
# PReNet
## 1 Introduction
"Progressive Image Deraining Networks: A Better and Simpler Baseline" provides a better and simpler baseline deraining network by considering network architecture, input and output, and loss functions.
<div align="center">
<img src="https://github.com/simonsLiang/PReNet_paddle/blob/main/data/net.jpg" width="800">
</div>
## 2 How to use
### 2.1 Prepare dataset
The dataset (RainH.zip) used by PReNet can be downloaded from [here](https://pan.baidu.com/s/1_vxCatOV3sOA6Vkx1l23eA?pwd=vitu). Uncompress it to get two folders: RainTrainH and Rain100H.
The dataset is structured as follows:
```
├── data
    ├── RainTrainH
        ├── rain
            ├── 1.png
            └── 2.png
                .
                .
        └── norain
            ├── 1.png
            └── 2.png
                .
                .
    └── Rain100H
        ├── rain
            ├── 001.png
            └── 002.png
                .
                .
        └── norain
            ├── 001.png
            └── 002.png
                .
                .
```
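Before training, it can help to confirm that every rainy image has a clean counterpart. The sketch below is one way to do that with the standard library only. It assumes that sorting the file names in `rain` and `norain` puts matching images at the same position; the helper name `paired_files` is hypothetical and is not part of PaddleGAN.

```python
from pathlib import Path


def paired_files(root, split):
    """Return (rain, norain) path pairs for one split, e.g. "Rain100H".

    Assumption: the sorted order of the two folders corresponds, so the
    i-th rainy image belongs to the i-th clean image.
    """
    root = Path(root)
    rain = sorted((root / split / "rain").glob("*.png"))
    clean = sorted((root / split / "norain").glob("*.png"))
    if len(rain) != len(clean):
        raise ValueError(
            f"{split}: {len(rain)} rainy images but {len(clean)} clean images")
    return list(zip(rain, clean))
```

For example, `paired_files("data", "RainTrainH")` would list the training pairs if the dataset was unpacked as shown in the tree above.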
### 2.2 Train/Test
Train the model:
```
python -u tools/main.py --config-file configs/prenet.yaml
```
Test the model:
```
python tools/main.py --config-file configs/prenet.yaml --evaluate-only --load ${PATH_OF_WEIGHT}
```
## 3 Results
Input:
<div align="center">
<img src="https://github.com/simonsLiang/PReNet_paddle/blob/main/data/rain-001.png" width="300">
</div>
Output:
<div align="center">
<img src="https://github.com/simonsLiang/PReNet_paddle/blob/main/data/derain-rain-001.png" width="300">
</div>
## 4 Model Download
| model | dataset |
|---|---|
| [PReNet](https://paddlegan.bj.bcebos.com/models/PReNet.pdparams) | [RainH.zip](https://pan.baidu.com/s/1_vxCatOV3sOA6Vkx1l23eA?pwd=vitu) |
## References
- 1. [Progressive Image Deraining Networks: A Better and Simpler Baseline](https://arxiv.org/pdf/1901.09221v3.pdf)
```
@inproceedings{ren2019progressive,
title={Progressive Image Deraining Networks: A Better and Simpler Baseline},
author={Ren, Dongwei and Zuo, Wangmeng and Hu, Qinghua and Zhu, Pengfei and Meng, Deyu},
booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
year={2019},
}
```
# PReNet
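## 1 Introduction
"Progressive Image Deraining Networks: A Better and Simpler Baseline" proposes a multi-stage progressive residual network. Each stage is a ResNet, and each ResNet block takes as input the output of the previous block together with the original rainy image. The network is trained with an SSIM loss, which further improves its performance. The overall design is simple and efficient, performs well on a variety of datasets, and provides a good baseline for image deraining.
<div align="center">
<img src="https://github.com/simonsLiang/PReNet_paddle/blob/main/data/net.jpg" width="800">
</div>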
## 2 How to use
### 2.1 Prepare dataset
The dataset (RainH.zip) can be downloaded [here](https://pan.baidu.com/s/1_vxCatOV3sOA6Vkx1l23eA?pwd=vitu); extract it into the ./data directory.
The dataset is structured as follows:
```
├── data
    ├── RainTrainH
        ├── rain
            ├── 1.png
            └── 2.png
                .
                .
        └── norain
            ├── 1.png
            └── 2.png
                .
                .
    └── Rain100H
        ├── rain
            ├── 001.png
            └── 002.png
                .
                .
        └── norain
            ├── 001.png
            └── 002.png
                .
                .
```
### 2.2 Train/Test
Train the model:
```
python -u tools/main.py --config-file configs/prenet.yaml
```
Test the model:
```
python tools/main.py --config-file configs/prenet.yaml --evaluate-only --load ${PATH_OF_WEIGHT}
```
## 3 Results
Input:
<div align="center">
<img src="https://github.com/simonsLiang/PReNet_paddle/blob/main/data/rain-001.png" width="300">
</div>
Output:
<div align="center">
<img src="https://github.com/simonsLiang/PReNet_paddle/blob/main/data/derain-rain-001.png" width="300">
</div>
## 4 Model Download
| Model | Dataset |
|---|---|
| [PReNet](https://paddlegan.bj.bcebos.com/models/PReNet.pdparams) | [RainH.zip](https://pan.baidu.com/s/1_vxCatOV3sOA6Vkx1l23eA?pwd=vitu) |
## References
- 1. [Progressive Image Deraining Networks: A Better and Simpler Baseline](https://arxiv.org/pdf/1901.09221v3.pdf)
```
@inproceedings{ren2019progressive,
title={Progressive Image Deraining Networks: A Better and Simpler Baseline},
author={Ren, Dongwei and Zuo, Wangmeng and Hu, Qinghua and Zhu, Pengfei and Meng, Deyu},
booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
year={2019},
}
```
......@@ -3,6 +3,6 @@ from .transforms import (PairedRandomCrop, PairedRandomHorizontalFlip,
PairedRandomVerticalFlip, PairedRandomTransposeHW,
SRPairedRandomCrop, SplitPairedImage, SRNoise,
NormalizeSequence, MirrorVideoSequence,
TransposeSequence)
TransposeSequence, PairedToTensor)
from .builder import build_preprocess
......@@ -22,6 +22,7 @@ import numpy as np
from PIL import Image
import paddle
import paddle.vision.transforms as T
import paddle.vision.transforms.functional as F
......@@ -42,10 +43,12 @@ TRANSFORMS.register(T.RandomVerticalFlip)
TRANSFORMS.register(T.Normalize)
TRANSFORMS.register(T.Transpose)
TRANSFORMS.register(T.Grayscale)
TRANSFORMS.register(T.ToTensor)
@PREPROCESS.register()
class Transforms():
def __init__(self, pipeline, input_keys, output_keys=None):
self.input_keys = input_keys
self.output_keys = output_keys
......@@ -81,6 +84,7 @@ class Transforms():
@PREPROCESS.register()
class SplitPairedImage:
def __init__(self, key, paired_keys=['A', 'B']):
self.key = key
self.paired_keys = paired_keys
......@@ -103,6 +107,7 @@ class SplitPairedImage:
@TRANSFORMS.register()
class PairedRandomCrop(T.RandomCrop):
def __init__(self, size, keys=None):
super().__init__(size, keys=keys)
......@@ -122,8 +127,19 @@ class PairedRandomCrop(T.RandomCrop):
return F.crop(img, i, j, h, w)
@TRANSFORMS.register()
class PairedToTensor(T.ToTensor):
    def __init__(self, data_format='CHW', keys=None):
        super().__init__(data_format, keys=keys)

    def _apply_image(self, img):
        return F.to_tensor(img)
@TRANSFORMS.register()
class PairedRandomHorizontalFlip(T.RandomHorizontalFlip):
def __init__(self, prob=0.5, keys=None):
super().__init__(prob, keys=keys)
......@@ -143,6 +159,7 @@ class PairedRandomHorizontalFlip(T.RandomHorizontalFlip):
@TRANSFORMS.register()
class PairedRandomVerticalFlip(T.RandomHorizontalFlip):
def __init__(self, prob=0.5, keys=None):
super().__init__(prob, keys=keys)
......@@ -176,6 +193,7 @@ class PairedRandomTransposeHW(T.BaseTransform):
prob (float): The probability to transpose the images.
keys (list[str]): The images to be transposed.
"""
def __init__(self, prob=0.5, keys=None):
self.keys = keys
self.prob = prob
......@@ -220,6 +238,7 @@ class TransposeSequence(T.Transpose):
fake_img_seq = transform(fake_img_seq)
"""
def _apply_image(self, img):
if isinstance(img, list):
imgs = []
......@@ -277,6 +296,7 @@ class NormalizeSequence(T.Normalize):
fake_img_seq = normalize_seq(fake_img_seq)
"""
def _apply_image(self, img):
if isinstance(img, list):
imgs = [
......@@ -302,6 +322,7 @@ class SRPairedRandomCrop(T.BaseTransform):
scale (int): model upscale factor.
gt_patch_size (int): cropped gt patch size.
"""
def __init__(self, scale, gt_patch_size, scale_list=False, keys=None):
self.gt_patch_size = gt_patch_size
self.scale = scale
......@@ -368,6 +389,7 @@ class SRNoise(T.BaseTransform):
noise_path (str): directory of noise image.
size (int): cropped noise patch size.
"""
def __init__(self, noise_path, size, keys=None):
self.noise_path = noise_path
self.noise_imgs = sorted(glob.glob(noise_path + '*.png'))
......@@ -396,6 +418,7 @@ class RandomResizedCropProb(T.RandomResizedCrop):
prob (float): probability of using random-resized cropping.
size (int): cropped size.
"""
def __init__(self, prob, size, scale, ratio, interpolation, keys=None):
super().__init__(size, scale, ratio, interpolation)
self.prob = prob
......@@ -409,6 +432,7 @@ class RandomResizedCropProb(T.RandomResizedCrop):
@TRANSFORMS.register()
class Add(T.BaseTransform):
def __init__(self, value, keys=None):
"""Initialize Add Transform
......@@ -430,6 +454,7 @@ class Add(T.BaseTransform):
@TRANSFORMS.register()
class ResizeToScale(T.BaseTransform):
def __init__(self,
size: int,
scale: int,
......@@ -480,6 +505,7 @@ class ResizeToScale(T.BaseTransform):
@TRANSFORMS.register()
class PairedColorJitter(T.BaseTransform):
def __init__(self,
brightness=0,
contrast=0,
......@@ -545,6 +571,7 @@ class MirrorVideoSequence:
Args:
keys (list[str]): The frame lists to be extended.
"""
def __init__(self, keys=None):
self.keys = keys
......
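The `Paired*` transforms in this file are applied to the `lq` and `gt` images together, which matters for random augmentations: as their names suggest, the rainy input and its ground truth must receive the same flip, crop, or transpose. The sketch below illustrates that idea in plain Python with images represented as nested lists of rows. It is not PaddleGAN's implementation, and the function name `paired_random_hflip` and its `rng` parameter are assumptions made for illustration.

```python
import random


def paired_random_hflip(lq, gt, prob=0.5, rng=random):
    """Flip both images left-to-right together, or leave both untouched.

    A single random draw decides the outcome for the pair, so the input
    and its ground truth always stay spatially aligned.
    """
    if rng.random() < prob:
        lq = [row[::-1] for row in lq]
        gt = [row[::-1] for row in gt]
    return lq, gt


# Toy 2x3 "images": each inner list is one row of pixels.
rainy = [[1, 2, 3], [4, 5, 6]]
clean = [[7, 8, 9], [10, 11, 12]]
print(paired_random_hflip(rainy, clean, prob=1.0))
```

Drawing the random number once per pair, rather than once per image, is the design choice that keeps supervision valid after augmentation.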
......@@ -35,3 +35,4 @@ from .mpr_model import MPRModel
from .photopen_model import PhotoPenModel
from .msvsr_model import MultiStageVSRModel
from .singan_model import SinGANModel
from .prenet_model import PReNetModel
......@@ -7,3 +7,5 @@ from .photopen_perceptual_loss import PhotoPenPerceptualLoss
from .gradient_penalty import GradientPenalty
from .builder import build_criterion
from .ssim import SSIM
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# code was heavily based on https://github.com/csdwren/PReNet
# Users should be careful about adopting these functions in any commercial matters.
import numpy as np
from math import exp

import paddle
import paddle.nn.functional as F

from .builder import CRITERIONS


def gaussian(window_size, sigma):
    gauss = paddle.to_tensor([
        exp(-(x - window_size // 2)**2 / float(2 * sigma**2))
        for x in range(window_size)
    ])
    return gauss / gauss.sum()


def create_window(window_size, channel):
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).unsqueeze(0).unsqueeze(0)
    window = paddle.to_tensor(paddle.expand(
        _2D_window, (channel, 1, window_size, window_size)),
                              stop_gradient=False)
    return window


def _ssim(img1, img2, window, window_size, channel, size_average=True):
    mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
    mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    sigma1_sq = F.conv2d(
        img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(
        img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
    sigma12 = F.conv2d(
        img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2

    C1 = 0.01**2
    C2 = 0.03**2

    ssim_map = ((2 * mu1_mu2 + C1) *
                (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
                                       (sigma1_sq + sigma2_sq + C2))

    if size_average:
        return ssim_map.mean()
    else:
        return ssim_map.mean(1).mean(1).mean(1)


@CRITERIONS.register()
class SSIM(paddle.nn.Layer):
    def __init__(self, window_size=11, size_average=True):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2):
        (_, channel, _, _) = img1.shape

        if channel == self.channel and self.window.dtype == img1.dtype:
            window = self.window
        else:
            window = create_window(self.window_size, channel)
            tt = img1.dtype
            window = paddle.to_tensor(window, dtype=tt)
            self.window = window
            self.channel = channel

        return _ssim(img1, img2, window, self.window_size, channel,
                     self.size_average)


def ssim(img1, img2, window_size=11, size_average=True):
    # Paddle tensors expose `shape` (not `size()`) and have no
    # `is_cuda` / `type_as`; cast the window to the input dtype instead.
    (_, channel, _, _) = img1.shape
    window = create_window(window_size, channel)
    window = paddle.cast(window, img1.dtype)

    return _ssim(img1, img2, window, window_size, channel, size_average)
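The `_ssim` function above evaluates the SSIM formula on local statistics gathered with an 11x11 Gaussian window. To make the formula itself easier to follow, here is a simplified pure-Python sketch that uses whole-signal means, variances, and covariance instead of windowed ones. It uses the same constants `C1 = 0.01**2` and `C2 = 0.03**2` and assumes values in [0, 1]; the function name `global_ssim` is hypothetical, and its result is not numerically equivalent to the windowed version.

```python
def global_ssim(x, y, c1=0.01 ** 2, c2=0.03 ** 2):
    """SSIM of two equal-length sequences using global statistics.

    Simplification: no Gaussian window, so this illustrates the formula
    rather than reproducing the windowed SSIM used for training.
    """
    n = len(x)
    mu_x = sum(x) / n
    mu_y = sum(y) / n
    # Biased (population) variance and covariance, as E[ab] - E[a]E[b].
    var_x = sum(v * v for v in x) / n - mu_x * mu_x
    var_y = sum(v * v for v in y) / n - mu_y * mu_y
    cov = sum(a * b for a, b in zip(x, y)) / n - mu_x * mu_y
    numerator = (2 * mu_x * mu_y + c1) * (2 * cov + c2)
    denominator = (mu_x * mu_x + mu_y * mu_y + c1) * (var_x + var_y + c2)
    return numerator / denominator


signal = [0.1, 0.5, 0.9, 0.3]
print(global_ssim(signal, signal))            # identical inputs
print(global_ssim(signal, [0.9, 0.5, 0.1, 0.7]))  # anti-correlated inputs
```

Because SSIM grows as two images become more similar, a training loop that minimizes a loss needs to negate it, which is what the model class in this commit does with `-self.pixel_criterion(...)`.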
......@@ -39,3 +39,4 @@ from .generater_photopen import SPADEGenerator
from .basicvsr_plus_plus import BasicVSRPlusPlus
from .msvsr import MSVSR
from .generator_singan import SinGANGenerator
from .prenet import PReNet
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# code was heavily based on https://github.com/csdwren/PReNet
# Users should be careful about adopting these functions in any commercial matters.
import math

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from .builder import GENERATORS


def convWithBias(in_channels, out_channels, kernel_size, stride, padding):
    """ Obtain a 2d convolution layer with bias and initialized by KaimingUniform
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        kernel_size (int): Convolution kernel size
        stride (int): Convolution stride
        padding (int|tuple): Convolution padding.
    """
    if isinstance(kernel_size, int):
        fan_in = kernel_size * kernel_size * in_channels
    else:
        fan_in = kernel_size[0] * kernel_size[1] * in_channels
    bound = 1 / math.sqrt(fan_in)
    bias_attr = paddle.framework.ParamAttr(
        initializer=nn.initializer.Uniform(-bound, bound))
    weight_attr = paddle.framework.ParamAttr(
        initializer=nn.initializer.KaimingUniform(fan_in=6 * fan_in))
    conv = nn.Conv2D(in_channels,
                     out_channels,
                     kernel_size,
                     stride,
                     padding,
                     weight_attr=weight_attr,
                     bias_attr=bias_attr)
    return conv


@GENERATORS.register()
class PReNet(nn.Layer):
    """
    Args:
        recurrent_iter (int): Number of iterations.
            Default: 6.
        use_GPU (bool): whether use gpu or not .
            Default: True.
    """
    def __init__(self, recurrent_iter=6, use_GPU=True):
        super(PReNet, self).__init__()
        self.iteration = recurrent_iter
        self.use_GPU = use_GPU

        self.conv0 = nn.Sequential(convWithBias(6, 32, 3, 1, 1), nn.ReLU())
        self.res_conv1 = nn.Sequential(convWithBias(32, 32, 3, 1, 1), nn.ReLU(),
                                       convWithBias(32, 32, 3, 1, 1), nn.ReLU())
        self.res_conv2 = nn.Sequential(convWithBias(32, 32, 3, 1, 1), nn.ReLU(),
                                       convWithBias(32, 32, 3, 1, 1), nn.ReLU())
        self.res_conv3 = nn.Sequential(convWithBias(32, 32, 3, 1, 1), nn.ReLU(),
                                       convWithBias(32, 32, 3, 1, 1), nn.ReLU())
        self.res_conv4 = nn.Sequential(convWithBias(32, 32, 3, 1, 1), nn.ReLU(),
                                       convWithBias(32, 32, 3, 1, 1), nn.ReLU())
        self.res_conv5 = nn.Sequential(convWithBias(32, 32, 3, 1, 1), nn.ReLU(),
                                       convWithBias(32, 32, 3, 1, 1), nn.ReLU())
        self.conv_i = nn.Sequential(convWithBias(32 + 32, 32, 3, 1, 1),
                                    nn.Sigmoid())
        self.conv_f = nn.Sequential(convWithBias(32 + 32, 32, 3, 1, 1),
                                    nn.Sigmoid())
        self.conv_g = nn.Sequential(convWithBias(32 + 32, 32, 3, 1, 1),
                                    nn.Tanh())
        self.conv_o = nn.Sequential(convWithBias(32 + 32, 32, 3, 1, 1),
                                    nn.Sigmoid())
        self.conv = nn.Sequential(convWithBias(32, 3, 3, 1, 1), )

    def forward(self, input):
        batch_size, row, col = input.shape[0], input.shape[2], input.shape[3]

        x = input
        h = paddle.to_tensor(paddle.zeros(shape=(batch_size, 32, row, col),
                                          dtype='float32'),
                             stop_gradient=False)
        c = paddle.to_tensor(paddle.zeros(shape=(batch_size, 32, row, col),
                                          dtype='float32'),
                             stop_gradient=False)

        x_list = []
        for _ in range(self.iteration):
            x = paddle.concat((input, x), 1)
            x = self.conv0(x)

            x = paddle.concat((x, h), 1)
            i = self.conv_i(x)
            f = self.conv_f(x)
            g = self.conv_g(x)
            o = self.conv_o(x)
            c = f * c + i * g
            h = o * paddle.tanh(c)

            x = h
            resx = x
            x = F.relu(self.res_conv1(x) + resx)
            resx = x
            x = F.relu(self.res_conv2(x) + resx)
            resx = x
            x = F.relu(self.res_conv3(x) + resx)
            resx = x
            x = F.relu(self.res_conv4(x) + resx)
            resx = x
            x = F.relu(self.res_conv5(x) + resx)
            x = self.conv(x)
            x = x + input
            x_list.append(x)
        return x
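Inside the loop of `forward`, the lines computing `i`, `f`, `g`, `o`, `c`, and `h` form a convolutional LSTM update whose state is carried across the recurrent iterations. The sketch below reduces that update to scalars so the gate arithmetic can be checked by hand. It is an illustration only: in the network the four pre-activations come from convolutions over the concatenation of `x` and `h`, whereas here they are passed in directly, and the names `sigmoid` and `lstm_step` are hypothetical.

```python
import math


def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))


def lstm_step(z_i, z_f, z_g, z_o, c_prev):
    """One LSTM state update from scalar pre-activations.

    Mirrors `c = f * c + i * g` and `h = o * tanh(c)` from the generator.
    """
    i = sigmoid(z_i)    # input gate: how much of the candidate to write
    f = sigmoid(z_f)    # forget gate: how much of the old cell state to keep
    g = math.tanh(z_g)  # candidate cell content
    o = sigmoid(z_o)    # output gate: how much of the cell state to expose
    c = f * c_prev + i * g
    h = o * math.tanh(c)
    return h, c


# Carry the state over a few iterations, as the generator does over stages.
h, c = 0.0, 0.0
for _ in range(3):
    h, c = lstm_step(1.0, 1.0, 0.5, 1.0, c)
    print(h, c)
```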
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn

from .builder import MODELS
from .sr_model import BaseSRModel
from .generators.iconvsr import EDVRFeatureExtractor
from .generators.basicvsr import ResidualBlockNoBN, PixelShufflePack, SPyNet
from ..modules.init import reset_parameters
from ..utils.visual import tensor2img


@MODELS.register()
class PReNetModel(BaseSRModel):
    """PReNet Model.

    Paper: Progressive Image Deraining Networks: A Better and Simpler Baseline, IEEE,2019
    """
    def __init__(self, generator, pixel_criterion=None):
        """Initialize the PReNetModel class.

        Args:
            generator (dict): config of generator.
            pixel_criterion (dict): config of pixel criterion.
        """
        super(PReNetModel, self).__init__(generator, pixel_criterion)
        self.current_iter = 1
        self.flag = True

    def setup_input(self, input):
        self.lq = input['lq']
        self.visual_items['lq'] = self.lq[0, :, :, :]
        if 'gt' in input:
            self.gt = input['gt']
            self.visual_items['gt'] = self.gt[0, :, :, :]
        self.image_paths = input['lq_path']

    def train_iter(self, optims=None):
        optims['optim'].clear_grad()
        self.output = self.nets['generator'](self.lq)
        self.visual_items['output'] = self.output[0, :, :, :]
        # pixel loss: SSIM is a similarity, so its negative is minimized
        loss_pixel = -self.pixel_criterion(self.output, self.gt)
        loss_pixel.backward()
        optims['optim'].step()

        self.losses['loss_pixel'] = loss_pixel

        self.current_iter += 1

    def test_iter(self, metrics=None):
        self.gt = self.gt.cpu()
        self.nets['generator'].eval()
        with paddle.no_grad():
            output = self.nets['generator'](self.lq)
            self.visual_items['output'] = output[0, :, :, :].cpu()
        self.nets['generator'].train()

        out_img = []
        gt_img = []

        out_tensor = output[0]
        gt_tensor = self.gt[0]
        out_img = tensor2img(out_tensor, (0., 1.))
        gt_img = tensor2img(gt_tensor, (0., 1.))

        if metrics is not None:
            for metric in metrics.values():
                metric.update(out_img, gt_img, is_seq=True)
......@@ -15,7 +15,7 @@ from ppgan.metrics import build_metric
MODEL_CLASSES = ["pix2pix", "cyclegan", "wav2lip", "esrgan", \
"edvr", "fom", "stylegan2", "basicvsr", "msvsr", "singan"]
"edvr", "fom", "stylegan2", "basicvsr", "msvsr", "singan","prenet"]
def parse_args():
......@@ -56,63 +56,55 @@ def parse_args():
parser.add_argument('--seed',
type=int,
default=None,
help='fix random numbers by setting seed\".'
)
help='fix random numbers by setting seed\".')
# for tensorRT
parser.add_argument(
"--run_mode",
default="fluid",
type=str,
choices=["fluid", "trt_fp32", "trt_fp16"],
help="mode of running(fluid/trt_fp32/trt_fp16)")
parser.add_argument(
"--trt_min_shape",
default=1,
type=int,
help="trt_min_shape for tensorRT")
parser.add_argument(
"--trt_max_shape",
default=1280,
type=int,
help="trt_max_shape for tensorRT")
parser.add_argument(
"--trt_opt_shape",
default=640,
type=int,
help="trt_opt_shape for tensorRT")
parser.add_argument(
"--min_subgraph_size",
default=3,
type=int,
help="trt_opt_shape for tensorRT")
parser.add_argument(
"--batch_size",
default=1,
type=int,
help="batch_size for tensorRT")
parser.add_argument(
"--use_dynamic_shape",
dest="use_dynamic_shape",
action="store_true",
help="use_dynamic_shape for tensorRT")
parser.add_argument(
"--trt_calib_mode",
dest="trt_calib_mode",
action="store_true",
help="trt_calib_mode for tensorRT")
parser.add_argument("--run_mode",
default="fluid",
type=str,
choices=["fluid", "trt_fp32", "trt_fp16"],
help="mode of running(fluid/trt_fp32/trt_fp16)")
parser.add_argument("--trt_min_shape",
default=1,
type=int,
help="trt_min_shape for tensorRT")
parser.add_argument("--trt_max_shape",
default=1280,
type=int,
help="trt_max_shape for tensorRT")
parser.add_argument("--trt_opt_shape",
default=640,
type=int,
help="trt_opt_shape for tensorRT")
parser.add_argument("--min_subgraph_size",
default=3,
type=int,
help="trt_opt_shape for tensorRT")
parser.add_argument("--batch_size",
default=1,
type=int,
help="batch_size for tensorRT")
parser.add_argument("--use_dynamic_shape",
dest="use_dynamic_shape",
action="store_true",
help="use_dynamic_shape for tensorRT")
parser.add_argument("--trt_calib_mode",
dest="trt_calib_mode",
action="store_true",
help="trt_calib_mode for tensorRT")
args = parser.parse_args()
return args
def create_predictor(model_path, device="gpu",
run_mode='fluid',
batch_size=1,
min_subgraph_size=3,
use_dynamic_shape=False,
trt_min_shape=1,
trt_max_shape=1280,
trt_opt_shape=640,
trt_calib_mode=False):
def create_predictor(model_path,
device="gpu",
run_mode='fluid',
batch_size=1,
min_subgraph_size=3,
use_dynamic_shape=False,
trt_min_shape=1,
trt_max_shape=1280,
trt_opt_shape=640,
trt_calib_mode=False):
config = paddle.inference.Config(model_path + ".pdmodel",
model_path + ".pdiparams")
if device == "gpu":
......@@ -123,20 +115,19 @@ def create_predictor(model_path, device="gpu",
config.enable_xpu(100)
else:
config.disable_gpu()
precision_map = {
'trt_int8': paddle.inference.Config.Precision.Int8,
'trt_fp32': paddle.inference.Config.Precision.Float32,
'trt_fp16': paddle.inference.Config.Precision.Half
}
if run_mode in precision_map.keys():
config.enable_tensorrt_engine(
workspace_size=1 << 25,
max_batch_size=batch_size,
min_subgraph_size=min_subgraph_size,
precision_mode=precision_map[run_mode],
use_static=False,
use_calib_mode=trt_calib_mode)
config.enable_tensorrt_engine(workspace_size=1 << 25,
max_batch_size=batch_size,
min_subgraph_size=min_subgraph_size,
precision_mode=precision_map[run_mode],
use_static=False,
use_calib_mode=trt_calib_mode)
if use_dynamic_shape:
min_input_shape = {
......@@ -155,6 +146,7 @@ def create_predictor(model_path, device="gpu",
predictor = paddle.inference.create_predictor(config)
return predictor
def setup_metrics(cfg):
metrics = OrderedDict()
if isinstance(list(cfg.values())[0], dict):
......@@ -166,22 +158,18 @@ def setup_metrics(cfg):
return metrics
def main():
args = parse_args()
if args.seed:
paddle.seed(args.seed)
random.seed(args.seed)
np.random.seed(args.seed)
np.random.seed(args.seed)
cfg = get_config(args.config_file, args.opt)
predictor = create_predictor(args.model_path,
args.device,
args.run_mode,
args.batch_size,
args.min_subgraph_size,
args.use_dynamic_shape,
args.trt_min_shape,
args.trt_max_shape,
args.trt_opt_shape,
predictor = create_predictor(args.model_path, args.device, args.run_mode,
args.batch_size, args.min_subgraph_size,
args.use_dynamic_shape, args.trt_min_shape,
args.trt_max_shape, args.trt_opt_shape,
args.trt_calib_mode)
input_handles = [
predictor.get_input_handle(name)
......@@ -218,7 +206,9 @@ def main():
prediction = output_handle.copy_to_cpu()
prediction = paddle.to_tensor(prediction)
image_numpy = tensor2img(prediction[0], min_max)
save_image(image_numpy, os.path.join(args.output_path, "pix2pix/{}.png".format(i)))
save_image(
image_numpy,
os.path.join(args.output_path, "pix2pix/{}.png".format(i)))
metric_file = os.path.join(args.output_path, "pix2pix/metric.txt")
real_B = paddle.to_tensor(data['A'])
for metric in metrics.values():
......@@ -231,7 +221,9 @@ def main():
prediction = output_handle.copy_to_cpu()
prediction = paddle.to_tensor(prediction)
image_numpy = tensor2img(prediction[0], min_max)
save_image(image_numpy, os.path.join(args.output_path, "cyclegan/{}.png".format(i)))
save_image(
image_numpy,
os.path.join(args.output_path, "cyclegan/{}.png".format(i)))
metric_file = os.path.join(args.output_path, "cyclegan/metric.txt")
real_B = paddle.to_tensor(data['B'])
for metric in metrics.values():
......@@ -275,7 +267,9 @@ def main():
prediction = output_handle.copy_to_cpu()
prediction = paddle.to_tensor(prediction)
image_numpy = tensor2img(prediction[0], min_max)
save_image(image_numpy, os.path.join(args.output_path, "stylegan2/{}.png".format(i)))
save_image(
image_numpy,
os.path.join(args.output_path, "stylegan2/{}.png".format(i)))
metric_file = os.path.join(args.output_path, "stylegan2/metric.txt")
real_img = paddle.to_tensor(data['A'])
for metric in metrics.values():
......@@ -285,7 +279,8 @@ def main():
input_handles[0].copy_from_cpu(lq)
predictor.run()
if len(predictor.get_output_names()) > 1:
output_handle = predictor.get_output_handle(predictor.get_output_names()[-1])
output_handle = predictor.get_output_handle(
predictor.get_output_names()[-1])
prediction = output_handle.copy_to_cpu()
prediction = paddle.to_tensor(prediction)
_, t, _, _, _ = prediction.shape
......@@ -295,13 +290,16 @@ def main():
for ti in range(t):
out_tensor = prediction[0, ti]
gt_tensor = data['gt'][0, ti]
out_img.append(tensor2img(out_tensor, (0.,1.)))
gt_img.append(tensor2img(gt_tensor, (0.,1.)))
out_img.append(tensor2img(out_tensor, (0., 1.)))
gt_img.append(tensor2img(gt_tensor, (0., 1.)))
image_numpy = tensor2img(prediction[0], min_max)
save_image(image_numpy, os.path.join(args.output_path, model_type, "{}.png".format(i)))
save_image(
image_numpy,
os.path.join(args.output_path, model_type, "{}.png".format(i)))
metric_file = os.path.join(args.output_path, model_type, "metric.txt")
metric_file = os.path.join(args.output_path, model_type,
"metric.txt")
for metric in metrics.values():
metric.update(out_img, gt_img, is_seq=True)
elif model_type == "singan":
......@@ -309,18 +307,38 @@ def main():
prediction = output_handle.copy_to_cpu()
prediction = paddle.to_tensor(prediction)
image_numpy = tensor2img(prediction, min_max)
save_image(image_numpy, os.path.join(args.output_path, "singan/{}.png".format(i)))
save_image(
image_numpy,
os.path.join(args.output_path, "singan/{}.png".format(i)))
metric_file = os.path.join(args.output_path, "singan/metric.txt")
for metric in metrics.values():
metric.update(prediction, data['A'])
elif model_type == "prenet":
lq = data['lq'].numpy()
gt = data['gt'].numpy()
input_handles[0].copy_from_cpu(lq)
predictor.run()
prediction = output_handle.copy_to_cpu()
prediction = paddle.to_tensor(prediction)
gt = paddle.to_tensor(gt)
image_numpy = tensor2img(prediction, min_max)
gt_img = tensor2img(gt, min_max)
save_image(
image_numpy,
os.path.join(args.output_path, "prenet/{}.png".format(i)))
metric_file = os.path.join(args.output_path, "prenet/metric.txt")
for metric in metrics.values():
metric.update(image_numpy, gt_img)
if metrics:
log_file = open(metric_file, 'a')
for metric_name, metric in metrics.items():
loss_string = "Metric {}: {:.4f}".format(
metric_name, metric.accumulate())
loss_string = "Metric {}: {:.4f}".format(metric_name,
metric.accumulate())
print(loss_string, file=log_file)
log_file.close()
if __name__ == '__main__':
main()
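The inference script above accumulates each configured metric over the test images and appends the result to `metric.txt`. For orientation, PSNR for 8-bit images is conventionally computed as 20·log10(255 / sqrt(MSE)). The following is a minimal pure-Python sketch of that formula, not PaddleGAN's metric implementation: the PSNR metric configured in `prenet.yaml` also has `crop_border` and `test_y_channel` options, which are ignored here, and the function name `psnr` and its flat-sequence inputs are assumptions made for illustration.

```python
import math


def psnr(img1, img2, max_value=255.0):
    """PSNR in dB between two equal-length flat sequences of pixel values."""
    if len(img1) != len(img2):
        raise ValueError("images must have the same number of pixels")
    mse = sum((a - b) ** 2 for a, b in zip(img1, img2)) / len(img1)
    if mse == 0:
        # Identical images: the ratio is unbounded.
        return float("inf")
    return 20.0 * math.log10(max_value / math.sqrt(mse))


restored = [10, 20, 30, 40]
target = [11, 21, 31, 41]
print("Metric {}: {:.4f}".format("psnr", psnr(restored, target)))
```

The final line reuses the `"Metric {}: {:.4f}"` format string from the script so the output resembles one line of `metric.txt`.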