未验证 提交 95e5f4f3 编写于 作者: L LielinJiang 提交者: GitHub

Add drn model (#153)

* add drn model
上级 827d0cfc
total_iters: 1000000
output_dir: output_dir
# tensor range for function tensor2img
min_max:
(0., 255.)
model:
name: DRN
generator:
name: DRNGenerator
scale: (2, 4)
n_blocks: 30
n_feats: 16
n_colors: 3
rgb_range: 255
negval: 0.2
pixel_criterion:
name: L1Loss
dataset:
train:
name: SRDataset
gt_folder: data/DIV2K/DIV2K_train_HR_sub
lq_folder: data/DIV2K/DIV2K_train_LR_bicubic/X4_sub
num_workers: 4
batch_size: 8
scale: 4
preprocess:
- name: LoadImageFromFile
key: lq
- name: LoadImageFromFile
key: gt
- name: Transforms
input_keys: [lq, gt]
output_keys: [lq, lqx2, gt]
pipeline:
- name: SRPairedRandomCrop
gt_patch_size: 384
scale: 4
scale_list: True
keys: [image, image]
- name: PairedRandomHorizontalFlip
keys: [image, image, image]
- name: PairedRandomVerticalFlip
keys: [image, image, image]
- name: PairedRandomTransposeHW
keys: [image, image, image]
- name: Transpose
keys: [image, image, image]
- name: Normalize
mean: [0., 0., 0.]
std: [1., 1., 1.]
keys: [image, image, image]
test:
name: SRDataset
gt_folder: data/DIV2K/val_set14/Set14
lq_folder: data/DIV2K/val_set14/Set14_bicLRx4
scale: 4
preprocess:
- name: LoadImageFromFile
key: lq
- name: LoadImageFromFile
key: gt
- name: Transforms
input_keys: [lq, gt]
pipeline:
- name: Transpose
keys: [image, image]
- name: Normalize
mean: [0., 0., 0.]
std: [1., 1., 1.]
keys: [image, image]
lr_scheduler:
name: CosineAnnealingRestartLR
learning_rate: 0.0001
periods: [1000000]
restart_weights: [1]
eta_min: !!float 1e-7
optimizer:
optimG:
name: Adam
net_names:
- generator
weight_decay: 0.0
beta1: 0.9
beta2: 0.999
optimD:
name: Adam
net_names:
- dual_model_0
- dual_model_1
weight_decay: 0.0
beta1: 0.9
beta2: 0.999
validate:
interval: 5000
save_img: false
metrics:
psnr: # metric name, can be arbitrary
name: PSNR
crop_border: 4
test_y_channel: True
ssim:
name: SSIM
crop_border: 4
test_y_channel: True
log_config:
interval: 10
visiual_interval: 500
snapshot_config:
interval: 5000
......@@ -2,11 +2,11 @@
## 1.1 Principle
Super resolution is a process of upscaling and improving the details within an image. It usually takes a low-resolution image as input and upscales the same image to a higher resolution as output.
Here we provide three super-resolution models, namely [RealSR](https://openaccess.thecvf.com/content_CVPRW_2020/papers/w31/Ji_Real-World_Super-Resolution_via_Kernel_Estimation_and_Noise_Injection_CVPRW_2020_paper.pdf), [ESRGAN](https://arxiv.org/abs/1809.00219v2), [LESRCNN](https://arxiv.org/abs/2007.04344).
Super resolution is a process of upscaling and improving the details within an image. It usually takes a low-resolution image as input and upscales the same image to a higher resolution as output.
Here we provide three super-resolution models, namely [RealSR](https://openaccess.thecvf.com/content_CVPRW_2020/papers/w31/Ji_Real-World_Super-Resolution_via_Kernel_Estimation_and_Noise_Injection_CVPRW_2020_paper.pdf), [ESRGAN](https://arxiv.org/abs/1809.00219v2), [LESRCNN](https://arxiv.org/abs/2007.04344).
[RealSR](https://openaccess.thecvf.com/content_CVPRW_2020/papers/w31/Ji_Real-World_Super-Resolution_via_Kernel_Estimation_and_Noise_Injection_CVPRW_2020_paper.pdf) proposed a realworld super-resolution model aiming at better perception.
[ESRGAN](https://arxiv.org/abs/1809.00219v2) is an enhanced SRGAN that improves the three key components of SRGAN.
[LESRCNN](https://arxiv.org/abs/2007.04344) is a lightweight enhanced SR CNN (LESRCNN) with three successive sub-blocks.
[LESRCNN](https://arxiv.org/abs/2007.04344) is a lightweight enhanced SR CNN (LESRCNN) with three successive sub-blocks.
## 1.2 How to use
......@@ -32,7 +32,7 @@
├── DIV2K_valid_LR_bicubic
...
```
The structures of Set5 and Set14 are similar. Taking Set5 as an example, the structure is as following:
```
Set5
......@@ -71,7 +71,7 @@ The metrics are PSNR / SSIM.
| lesrcnn_x4 | 31.9476 / 0.8909 | 28.4110 / 0.7770 | 30.231 / 0.8326 |
| esrgan_psnr_x4 | 32.5512 / 0.8991 | 28.8114 / 0.7871 | 30.7565 / 0.8449 |
| esrgan_x4 | 28.7647 / 0.8187 | 25.0065 / 0.6762 | 26.9013 / 0.7542 |
<!-- ![](../../imgs/horse2zebra.png) -->
......@@ -85,6 +85,7 @@ The metrics are PSNR / SSIM.
| lesrcnn_x4 | DIV2K | [lesrcnn_x4](https://paddlegan.bj.bcebos.com/models/lesrcnn_x4.pdparams)
| esrgan_psnr_x4 | DIV2K | [esrgan_psnr_x4](https://paddlegan.bj.bcebos.com/models/esrgan_psnr_x4.pdparams)
| esrgan_x4 | DIV2K | [esrgan_x4](https://paddlegan.bj.bcebos.com/models/esrgan_x4.pdparams)
| drns_x4 | DIV2K | [drns_x4](https://paddlegan.bj.bcebos.com/models/DRNSx4.pdparams)
# References
......@@ -126,3 +127,13 @@ The metrics are PSNR / SSIM.
publisher={Elsevier}
}
```
- 4. [Closed-loop Matters: Dual Regression Networks for Single Image Super-Resolution](https://arxiv.org/pdf/2003.07018.pdf)
```
@inproceedings{guo2020closed,
title={Closed-loop Matters: Dual Regression Networks for Single Image Super-Resolution},
author={Guo, Yong and Chen, Jian and Wang, Jingdong and Chen, Qi and Cao, Jiezhang and Deng, Zeshuai and Xu, Yanwu and Tan, Mingkui},
booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
year={2020}
}
```
......@@ -21,6 +21,7 @@ from .makeup_model import MakeupModel
from .esrgan_model import ESRGAN
from .ugatit_model import UGATITModel
from .dc_gan_model import DCGANModel
from .drn_model import DRN
from .animeganv2_model import AnimeGANV2Model, AnimeGANV2PreTrainModel
from .styleganv2_model import StyleGAN2Model
from .wav2lip_model import Wav2LipModel
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from .generators.builder import build_generator
from .discriminators.builder import build_discriminator
from .generators.drn import DownBlock
from .sr_model import BaseSRModel
from .builder import MODELS
from .criterions import build_criterion
from ..modules.init import init_weights
from ..utils.visual import tensor2img
@MODELS.register()
class DRN(BaseSRModel):
"""
This class implements the DRN model.
DRN paper: https://arxiv.org/pdf/1809.00219.pdf
"""
def __init__(self,
generator,
lq_loss_weight=0.1,
dual_loss_weight=0.1,
discriminator=None,
pixel_criterion=None,
perceptual_criterion=None,
gan_criterion=None,
params=None):
"""Initialize the DRN class.
Args:
generator (dict): config of generator.
discriminator (dict): config of discriminator.
pixel_criterion (dict): config of pixel criterion.
perceptual_criterion (dict): config of perceptual criterion.
gan_criterion (dict): config of gan criterion.
"""
super(DRN, self).__init__(generator)
self.lq_loss_weight = lq_loss_weight
self.dual_loss_weight = dual_loss_weight
self.params = params
self.nets['generator'] = build_generator(generator)
init_weights(self.nets['generator'])
negval = generator.negval
n_feats = generator.n_feats
n_colors = generator.n_colors
self.scale = generator.scale
for i in range(len(self.scale)):
dual_model = DownBlock(negval, n_feats, n_colors, 2)
self.nets['dual_model_' + str(i)] = dual_model
init_weights(self.nets['dual_model_' + str(i)])
if discriminator:
self.nets['discriminator'] = build_discriminator(discriminator)
if pixel_criterion:
self.pixel_criterion = build_criterion(pixel_criterion)
if perceptual_criterion:
self.perceptual_criterion = build_criterion(perceptual_criterion)
if gan_criterion:
self.gan_criterion = build_criterion(gan_criterion)
def setup_input(self, input):
self.lq = paddle.fluid.dygraph.to_variable(input['lq'])
self.visual_items['lq'] = self.lq
if isinstance(self.scale, (list, tuple)) and len(
self.scale) == 2 and 'lqx2' in input:
self.lqx2 = input['lqx2']
if 'gt' in input:
self.gt = paddle.fluid.dygraph.to_variable(input['gt'])
self.visual_items['gt'] = self.gt
self.image_paths = input['lq_path']
def train_iter(self, optimizers=None):
lr = [self.lq]
if hasattr(self, 'lqx2'):
lr.append(self.lqx2)
hr = self.gt
sr = self.nets['generator'](self.lq)
sr2lr = []
for i in range(len(self.scale)):
sr2lr_i = self.nets['dual_model_' + str(i)](sr[i - len(self.scale)])
sr2lr.append(sr2lr_i)
# compute primary loss
loss_primary = self.pixel_criterion(sr[-1], hr)
for i in range(1, len(sr)):
if self.lq_loss_weight > 0.0:
loss_primary += self.pixel_criterion(
sr[i - 1 - len(sr)], lr[i - len(sr)]) * self.lq_loss_weight
# compute dual loss
loss_dual = self.pixel_criterion(sr2lr[0], lr[0])
for i in range(1, len(self.scale)):
if self.dual_loss_weight > 0.0:
loss_dual += self.pixel_criterion(sr2lr[i],
lr[i]) * self.dual_loss_weight
loss_total = loss_primary + loss_dual
optimizers['optimG'].clear_grad()
optimizers['optimD'].clear_grad()
loss_total.backward()
optimizers['optimG'].step()
optimizers['optimD'].step()
self.losses['loss_promary'] = loss_primary
self.losses['loss_dual'] = loss_dual
self.losses['loss_total'] = loss_total
def test_iter(self, metrics=None):
self.nets['generator'].eval()
with paddle.no_grad():
self.output = self.nets['generator'](self.lq)[-1]
self.visual_items['output'] = self.output
self.nets['generator'].train()
out_img = []
gt_img = []
for out_tensor, gt_tensor in zip(self.output, self.gt):
out_img.append(tensor2img(out_tensor, (0., 255.)))
gt_img.append(tensor2img(gt_tensor, (0., 255.)))
if metrics is not None:
for metric in metrics.values():
metric.update(out_img, gt_img)
......@@ -25,3 +25,4 @@ from .lesrcnn import LESRCNNGenerator
from .resnet_ugatit_p2c import ResnetUGATITP2CGenerator
from .generator_styleganv2 import StyleGANv2Generator
from .generator_pixel2style2pixel import Pixel2Style2Pixel
from .drn import DRNGenerator
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import math
import paddle
import paddle.nn as nn
from .builder import GENERATORS
def default_conv(in_channels, out_channels, kernel_size, bias=True):
return nn.Conv2D(in_channels,
out_channels,
kernel_size,
padding=(kernel_size // 2),
bias_attr=bias)
class MeanShift(nn.Conv2D):
def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1):
super(MeanShift, self).__init__(3, 3, kernel_size=1)
std = paddle.to_tensor(rgb_std)
self.weight.set_value(paddle.eye(3).reshape([3, 3, 1, 1]))
self.weight.set_value(self.weight / (std.reshape([3, 1, 1, 1])))
mean = paddle.to_tensor(rgb_mean)
self.bias.set_value(sign * rgb_range * mean / std)
self.weight.trainable = False
self.bias.trainable = False
class DownBlock(nn.Layer):
def __init__(self,
negval,
n_feats,
n_colors,
scale,
nFeat=None,
in_channels=None,
out_channels=None):
super(DownBlock, self).__init__()
if nFeat is None:
nFeat = n_feats
if in_channels is None:
in_channels = n_colors
if out_channels is None:
out_channels = n_colors
dual_block = [
nn.Sequential(
nn.Conv2D(in_channels,
nFeat,
kernel_size=3,
stride=2,
padding=1,
bias_attr=False), nn.LeakyReLU(negative_slope=negval))
]
for _ in range(1, int(math.log2(scale))):
dual_block.append(
nn.Sequential(
nn.Conv2D(nFeat,
nFeat,
kernel_size=3,
stride=2,
padding=1,
bias_attr=False),
nn.LeakyReLU(negative_slope=negval)))
dual_block.append(
nn.Conv2D(nFeat,
out_channels,
kernel_size=3,
stride=1,
padding=1,
bias_attr=False))
self.dual_module = nn.Sequential(*dual_block)
def forward(self, x):
x = self.dual_module(x)
return x
## Channel Attention (CA) Layer
class CALayer(nn.Layer):
def __init__(self, channel, reduction=16):
super(CALayer, self).__init__()
# global average pooling: feature --> point
self.avg_pool = nn.AdaptiveAvgPool2D(1)
# feature channel downscale and upscale --> channel weight
self.conv_du = nn.Sequential(
nn.Conv2D(channel,
channel // reduction,
1,
padding=0,
bias_attr=True), nn.ReLU(),
nn.Conv2D(channel // reduction,
channel,
1,
padding=0,
bias_attr=True), nn.Sigmoid())
def forward(self, x):
y = self.avg_pool(x)
y = self.conv_du(y)
return x * y
class RCAB(nn.Layer):
def __init__(self,
conv,
n_feat,
kernel_size,
reduction=16,
bias=True,
bn=False,
act=nn.ReLU(),
res_scale=1):
super(RCAB, self).__init__()
modules_body = []
for i in range(2):
modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias))
if bn: modules_body.append(nn.BatchNorm2D(n_feat))
if i == 0: modules_body.append(act)
modules_body.append(CALayer(n_feat, reduction))
self.body = nn.Sequential(*modules_body)
self.res_scale = res_scale
def forward(self, x):
res = self.body(x)
res += x
return res
class Upsampler(nn.Sequential):
def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True):
m = []
if (scale & (scale - 1)) == 0: # Is scale = 2^n?
for _ in range(int(math.log(scale, 2))):
m.append(conv(n_feats, 4 * n_feats, 3, bias))
m.append(nn.PixelShuffle(2))
if bn: m.append(nn.BatchNorm2D(n_feats))
if act == 'relu':
m.append(nn.ReLU())
elif act == 'prelu':
m.append(nn.PReLU(n_feats))
elif scale == 3:
m.append(conv(n_feats, 9 * n_feats, 3, bias))
m.append(nn.PixelShuffle(3))
if bn: m.append(nn.BatchNorm2D(n_feats))
if act == 'relu':
m.append(nn.ReLU())
elif act == 'prelu':
m.append(nn.PReLU(n_feats))
else:
raise NotImplementedError
super(Upsampler, self).__init__(*m)
@GENERATORS.register()
class DRNGenerator(nn.Layer):
"""DRNGenerator"""
def __init__(
self,
scale,
n_blocks=30,
n_feats=16,
n_colors=3,
rgb_range=255,
negval=0.2,
kernel_size=3,
conv=default_conv,
):
super(DRNGenerator, self).__init__()
self.scale = scale
self.phase = len(scale)
act = nn.ReLU()
self.upsample = nn.Upsample(scale_factor=max(scale),
mode='bicubic',
align_corners=False)
rgb_mean = (0.4488, 0.4371, 0.4040)
rgb_std = (1.0, 1.0, 1.0)
self.sub_mean = MeanShift(rgb_range, rgb_mean, rgb_std)
self.head = conv(n_colors, n_feats, kernel_size)
self.down = [
DownBlock(negval, n_feats, n_colors, 2, n_feats * pow(2, p),
n_feats * pow(2, p), n_feats * pow(2, p + 1))
for p in range(self.phase)
]
self.down = nn.LayerList(self.down)
up_body_blocks = [[
RCAB(conv, n_feats * pow(2, p), kernel_size, act=act)
for _ in range(n_blocks)
] for p in range(self.phase, 1, -1)]
up_body_blocks.insert(0, [
RCAB(conv, n_feats * pow(2, self.phase), kernel_size, act=act)
for _ in range(n_blocks)
])
# The fisrt upsample block
up = [[
Upsampler(conv, 2, n_feats * pow(2, self.phase), act=False),
conv(n_feats * pow(2, self.phase),
n_feats * pow(2, self.phase - 1),
kernel_size=1)
]]
# The rest upsample blocks
for p in range(self.phase - 1, 0, -1):
up.append([
Upsampler(conv, 2, 2 * n_feats * pow(2, p), act=False),
conv(2 * n_feats * pow(2, p),
n_feats * pow(2, p - 1),
kernel_size=1)
])
self.up_blocks = nn.LayerList()
for idx in range(self.phase):
self.up_blocks.append(nn.Sequential(*up_body_blocks[idx], *up[idx]))
# tail conv that output sr imgs
tail = [conv(n_feats * pow(2, self.phase), n_colors, kernel_size)]
for p in range(self.phase, 0, -1):
tail.append(conv(n_feats * pow(2, p), n_colors, kernel_size))
self.tail = nn.LayerList(tail)
self.add_mean = MeanShift(rgb_range, rgb_mean, rgb_std, 1)
def forward(self, x):
"""Forward function.
Args:
x (Tensor): Input tensor with shape (n, c, h, w).
Returns:
Tensor: Forward results.
"""
# upsample x to target sr size
x = self.upsample(x)
# preprocess
x = self.sub_mean(x)
x = self.head(x)
# down phases,
copies = []
for idx in range(self.phase):
copies.append(x)
x = self.down[idx](x)
# up phases
sr = self.tail[0](x)
sr = self.add_mean(sr)
results = [sr]
for idx in range(self.phase):
# upsample to SR features
x = self.up_blocks[idx](x)
# concat down features and upsample features
x = paddle.concat((x, copies[self.phase - idx - 1]), 1)
# output sr imgs
sr = self.tail[idx + 1](x)
sr = self.add_mean(sr)
results.append(sr)
return results
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册