Unverified commit 1e1e2ad2, author: kongdebug, committer: GitHub

[feature] add rcan model for remote sensing image super-resolution (#610)

* [feature] add rcan model for super-resolution
Parent 413376be
total_iters: 1000000
output_dir: output_dir
# tensor range for function tensor2img
min_max:
  (0., 255.)

model:
  name: RCANModel
  generator:
    name: RCAN
    scale: 4
    n_resgroups: 10
    n_resblocks: 20
  pixel_criterion:
    name: L1Loss

dataset:
  train:
    name: SRDataset
    gt_folder: data/DIV2K/DIV2K_train_HR_sub
    lq_folder: data/DIV2K/DIV2K_train_LR_bicubic/X4_sub
    num_workers: 4
    batch_size: 16
    scale: 4
    preprocess:
      - name: LoadImageFromFile
        key: lq
      - name: LoadImageFromFile
        key: gt
      - name: Transforms
        input_keys: [lq, gt]
        pipeline:
          - name: SRPairedRandomCrop
            gt_patch_size: 192
            scale: 4
            keys: [image, image]
          - name: PairedRandomHorizontalFlip
            keys: [image, image]
          - name: PairedRandomVerticalFlip
            keys: [image, image]
          - name: PairedRandomTransposeHW
            keys: [image, image]
          - name: Transpose
            keys: [image, image]
          - name: Normalize
            mean: [0., .0, 0.]
            std: [1., 1., 1.]
            keys: [image, image]
  test:
    name: SRDataset
    gt_folder: data/Set14/GTmod12
    lq_folder: data/Set14/LRbicx4
    scale: 4
    preprocess:
      - name: LoadImageFromFile
        key: lq
      - name: LoadImageFromFile
        key: gt
      - name: Transforms
        input_keys: [lq, gt]
        pipeline:
          - name: Transpose
            keys: [image, image]
          - name: Normalize
            mean: [0., .0, 0.]
            std: [1., 1., 1.]
            keys: [image, image]

lr_scheduler:
  name: CosineAnnealingRestartLR
  learning_rate: 0.0001
  periods: [1000000]
  restart_weights: [1]
  eta_min: !!float 1e-7

optimizer:
  name: Adam
  # add parameters of net_name to optim
  # name should in self.nets
  net_names:
    - generator
  beta1: 0.9
  beta2: 0.99

validate:
  interval: 2500
  save_img: false

  metrics:
    psnr: # metric name, can be arbitrary
      name: PSNR
      crop_border: 4
      test_y_channel: True
    ssim:
      name: SSIM
      crop_border: 4
      test_y_channel: True

log_config:
  interval: 10
  visiual_interval: 5000

snapshot_config:
  interval: 2500
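
To get a feel for what the `lr_scheduler` entries in this config imply, here is a minimal sketch of the usual cosine-annealing-with-restarts formula evaluated with the configured `learning_rate`, `periods`, `restart_weights` and `eta_min`. The function name `cosine_restart_lr` is hypothetical, and the real `CosineAnnealingRestartLR` scheduler may differ in boundary handling; this is only an illustration of the decay curve.

```python
import math


def cosine_restart_lr(step,
                      base_lr=1e-4,
                      periods=(1_000_000, ),
                      restart_weights=(1, ),
                      eta_min=1e-7):
    """Approximate learning rate at `step` (hypothetical helper, not the library API)."""
    start = 0
    for period, weight in zip(periods, restart_weights):
        if step <= start + period:
            # Fraction of the current cosine cycle that has elapsed.
            progress = (step - start) / period
            return eta_min + weight * 0.5 * (base_lr - eta_min) * (
                1 + math.cos(math.pi * progress))
        start += period
    # Past the last period: stay at the floor value.
    return eta_min


if __name__ == "__main__":
    for s in (0, 250_000, 500_000, 1_000_000):
        print(s, cosine_restart_lr(s))
```

With a single period equal to `total_iters`, the rate simply decays from `learning_rate` to `eta_min` once, with no restart.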
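# 1. Single Remote Sensing Image Super-Resolution Reconstruction

## 1.1 Background and Principle

**Significance and application scenarios**: Single image super-resolution reconstruction has long been a popular task in low-level vision. It can serve as a technique for restoring old films and photographs, and it can supply higher-quality data to downstream tasks such as image segmentation and object detection. It is also widely applicable in remote sensing: in many remote sensing applications, such as **ship detection and classification**, **improving the resolution of remote sensing images is of great importance**.

**Principle**: Super-resolution reconstruction of a single remote sensing image is essentially the same problem as ordinary single image super-resolution: both generate a high-resolution image with clear textures from a low-resolution RGB three-channel image. This project reproduces the ECCV 2018 paper [Image Super-Resolution Using Very Deep Residual Channel Attention Networks](https://arxiv.org/abs/1807.02758) by [Yulun Zhang](http://yulunzhang.com/), [Kunpeng Li](https://kunpengli1994.github.io/), [Kai Li](http://kailigo.github.io/), [Lichen Wang](https://sites.google.com/site/lichenwang123/), [Bineng Zhong](https://scholar.google.de/citations?user=hvRBydsAAAAJ&hl=en), and [Yun Fu](http://www1.ece.neu.edu/~yunfu/).

The authors propose a very deep residual channel attention network (RCAN) and introduce a channel attention (CA) mechanism that adaptively rescales features by taking the interdependencies among channels into account. The model achieves excellent performance, so this project uses RCAN for x4 super-resolution reconstruction of single remote sensing images.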
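
The channel attention idea can be illustrated without any deep learning framework. The sketch below follows the same sequence as the `CALayer` in this commit (global average pooling, a 1x1 "down" projection with ReLU, a 1x1 "up" projection with sigmoid, then per-channel rescaling), but it works on plain nested lists. The function name and the weight arguments are hypothetical, chosen only for illustration.

```python
import math


def channel_attention(feature, w_down, b_down, w_up, b_up):
    """Rescale each channel of `feature` by a learned gate in (0, 1).

    feature: list of C channels, each an H x W nested list of floats.
    w_down:  (C // r) x C weights, b_down: C // r biases (1x1 conv, downscale).
    w_up:    C x (C // r) weights, b_up:   C biases      (1x1 conv, upscale).
    Returns (rescaled_feature, gates).
    """
    # Squeeze: global average pooling gives one scalar per channel.
    pooled = [
        sum(sum(row) for row in ch) / (len(ch) * len(ch[0])) for ch in feature
    ]
    # Excitation, step 1: channel downscale followed by ReLU.
    hidden = [
        max(0.0, sum(w * p for w, p in zip(ws, pooled)) + b)
        for ws, b in zip(w_down, b_down)
    ]
    # Excitation, step 2: channel upscale followed by sigmoid.
    gates = [
        1.0 / (1.0 + math.exp(-(sum(w * h for w, h in zip(ws, hidden)) + b)))
        for ws, b in zip(w_up, b_up)
    ]
    # Rescale: multiply every pixel of a channel by that channel's gate.
    rescaled = [[[v * g for v in row] for row in ch]
                for ch, g in zip(feature, gates)]
    return rescaled, gates
```

Because each gate depends on the pooled statistics of all channels, informative channels can be emphasized and less useful ones suppressed.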
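
## 1.2 How to Use

### 1.2.1 Data Preparation

Training in this project has two stages. The first stage pre-trains the RCANx4 model on the [DIV2K dataset](https://data.vision.ee.ethz.ch/cvl/DIV2K/); the second stage performs transfer learning from that model on the [remote sensing super-resolution dataset](https://aistudio.baidu.com/aistudio/datasetdetail/129011).

- To prepare the DIV2K data, refer to [this document](./single_image_super_resolution.md).
- Preparing the remote sensing super-resolution data
  - The data has been uploaded to AI Studio. It is a subset of remote sensing images drawn from the UC Merced Land-Use Dataset (21 land-use classes); HR-LR image pairs were generated with BI (bicubic) degradation for training super-resolution models. The training set contains 6720 pairs and the test set 420 pairs.
  - After downloading and extracting, the files are organized as follows:
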
```
├── RSdata_for_SR
    ├── train_HR
    ├── train_LR
    |    └──x4
    ├── test_HR
    ├── test_LR
    |    └──x4
```
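
Before training, it can help to confirm that the extracted folders really contain matching HR-LR pairs. The sketch below walks the layout shown above and pairs files by name. It assumes that an HR image and its LR counterpart share the same file name, which is an assumption about the dataset rather than something stated here; `list_sr_pairs` is a hypothetical helper.

```python
from pathlib import Path


def list_sr_pairs(root, split="train", scale=4):
    """Return sorted (lr_path, hr_path) pairs for one split of RSdata_for_SR.

    Assumes HR and LR files use identical file names (an assumption).
    """
    root = Path(root)
    hr_dir = root / f"{split}_HR"
    lr_dir = root / f"{split}_LR" / f"x{scale}"
    hr = {p.name: p for p in hr_dir.iterdir() if p.is_file()}
    lr = {p.name: p for p in lr_dir.iterdir() if p.is_file()}
    # Any name present on only one side means a broken pair.
    unmatched = sorted(set(hr) ^ set(lr))
    if unmatched:
        raise FileNotFoundError(f"unpaired images: {unmatched[:5]}")
    return [(lr[name], hr[name]) for name in sorted(hr)]
```

For the dataset described above, `len(list_sr_pairs("RSdata_for_SR", "train"))` would be expected to equal 6720 and the `"test"` split 420, provided the naming assumption holds.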
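
### 1.2.2 Training/Testing on the DIV2K Dataset

First, train the RCANx4 model on the DIV2K dataset, using Set14 as the test set. Following the paper, an RCANx2 model is needed as the initialization weights; it can be downloaded from the table below.

| Model | Dataset | Download |
|---|---|---|
| RCANx2 | DIV2K | [RCANx2](https://paddlegan.bj.bcebos.com/models/RCAN_X2_DIV2K.pdparams) |

After preparing the DIV2K data as described in [this document](./single_image_super_resolution.md), run the following command to train the model. The `--load` argument is the path to the downloaded RCANx2 weights.
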
```shell
python -u tools/main.py --config-file configs/rcan_rssr_x4.yaml --load ${PATH_OF_WEIGHT}
```

After training, run the following command to evaluate on the Set14 test set. The `--load` argument is the path to the trained RCANx4 weights.

```shell
python tools/main.py --config-file configs/rcan_rssr_x4.yaml --evaluate-only --load ${PATH_OF_WEIGHT}
```

The weights obtained at iteration 57250 of training on DIV2K, [RCAN_X4_DIV2K](https://pan.baidu.com/s/1rI7yUdD4T1DE0RZB5yHXjA) (extraction code: aglw), reach the following accuracy on Set14: `PSNR:28.8959 SSIM:0.7896`
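
For readers unfamiliar with the metric, PSNR is computed from the mean squared error as `10 * log10(MAX^2 / MSE)`. The sketch below is a simplified stand-in, not the project's metric code: it takes two single-channel images as nested lists (assumed to be already converted to the Y channel, matching `test_y_channel: True`) and trims `crop_border` pixels from every side, as configured with `crop_border: 4`.

```python
import math


def psnr(img1, img2, crop_border=4, max_value=255.0):
    """PSNR in dB between two equally sized single-channel images.

    img1, img2: H x W nested lists of pixel values in [0, max_value].
    The image must be larger than 2 * crop_border in both dimensions.
    """
    if crop_border > 0:
        # Ignore a border of `crop_border` pixels on every side.
        img1 = [r[crop_border:-crop_border] for r in img1[crop_border:-crop_border]]
        img2 = [r[crop_border:-crop_border] for r in img2[crop_border:-crop_border]]
    sq_errors = [(a - b)**2 for r1, r2 in zip(img1, img2)
                 for a, b in zip(r1, r2)]
    mse = sum(sq_errors) / len(sq_errors)
    if mse == 0:
        return float("inf")  # identical images
    return 10.0 * math.log10(max_value**2 / mse)
```

Higher values mean the reconstruction is closer to the ground truth; identical images yield infinity.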
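
### 1.2.3 Transfer Learning Training/Testing on the Remote Sensing Super-Resolution Data

- To use this dataset, change the high-resolution and low-resolution image paths of the training and test sets in `rcan_rssr_x4.yaml`, i.e. the `gt_folder` and `lq_folder` entries.
- Because transfer learning starts from the RCAN_X4_DIV2K weights trained on DIV2K, the number of training iterations `total_iters` can also be changed; good results do not require many iterations. For training, the `--load` argument is the path to the downloaded RCANx4 weights.

Train the model:
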
```shell
python -u tools/main.py --config-file configs/rcan_rssr_x4.yaml --load ${PATH_OF_RCANx4_WEIGHT}
```

Test the model:

```shell
python -u tools/main.py --config-file configs/rcan_rssr_x4.yaml --evaluate-only --load ${PATH_OF_RCANx4_WEIGHT}
```
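
The config changes described in this section can be sketched as a small function that operates on the config once it has been loaded into a Python dict. The helper name `make_finetune_config` is hypothetical, the folder names follow the `RSdata_for_SR` layout from section 1.2.1, and the iteration count in the usage note is a placeholder, not a recommended value.

```python
import copy


def make_finetune_config(cfg, total_iters, data_root="RSdata_for_SR", scale=4):
    """Return a copy of `cfg` pointed at the remote sensing dataset.

    cfg: dict with the same nesting as rcan_rssr_x4.yaml
         (keys `total_iters` and `dataset.{train,test}.{gt_folder,lq_folder}`).
    """
    new_cfg = copy.deepcopy(cfg)  # leave the caller's config untouched
    new_cfg["total_iters"] = total_iters
    for split in ("train", "test"):
        new_cfg["dataset"][split]["gt_folder"] = f"{data_root}/{split}_HR"
        new_cfg["dataset"][split]["lq_folder"] = f"{data_root}/{split}_LR/x{scale}"
    return new_cfg
```

For example, `make_finetune_config(cfg, total_iters=100000)` shortens the run and switches all four folders in one step; the resulting values can then be written back into `rcan_rssr_x4.yaml` by hand.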

## 1.3 Results

- RCANx4 super-resolution result on remote sensing images

<img src=../../imgs/RSSR.png></img>

- [Try the RCAN remote sensing image super-resolution project online on AI Studio](https://aistudio.baidu.com/aistudio/projectdetail/3508912)
......@@ -35,4 +35,5 @@ from .mpr_model import MPRModel
from .photopen_model import PhotoPenModel
from .msvsr_model import MultiStageVSRModel
from .singan_model import SinGANModel
from .rcan_model import RCANModel
from .prenet_model import PReNetModel
......@@ -39,4 +39,5 @@ from .generater_photopen import SPADEGenerator
from .basicvsr_plus_plus import BasicVSRPlusPlus
from .msvsr import MSVSR
from .generator_singan import SinGANGenerator
from .rcan import RCAN
from .prenet import PReNet
# based on https://github.com/kongdebug/RCAN-Paddle
import math

import paddle
import paddle.nn as nn

from .builder import GENERATORS


def default_conv(in_channels, out_channels, kernel_size, bias=True):
    weight_attr = paddle.ParamAttr(
        initializer=paddle.nn.initializer.XavierUniform(), need_clip=True)
    return nn.Conv2D(in_channels,
                     out_channels,
                     kernel_size,
                     padding=(kernel_size // 2),
                     weight_attr=weight_attr,
                     bias_attr=bias)


class MeanShift(nn.Conv2D):
    def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1):
        super(MeanShift, self).__init__(3, 3, kernel_size=1)
        std = paddle.to_tensor(rgb_std)
        self.weight.set_value(paddle.eye(3).reshape([3, 3, 1, 1]))
        self.weight.set_value(self.weight / (std.reshape([3, 1, 1, 1])))

        mean = paddle.to_tensor(rgb_mean)
        self.bias.set_value(sign * rgb_range * mean / std)

        self.weight.trainable = False
        self.bias.trainable = False


## Channel Attention (CA) Layer
class CALayer(nn.Layer):
    def __init__(self, channel, reduction=16):
        super(CALayer, self).__init__()
        # global average pooling: feature --> point
        self.avg_pool = nn.AdaptiveAvgPool2D(1)
        # feature channel downscale and upscale --> channel weight
        self.conv_du = nn.Sequential(
            nn.Conv2D(channel,
                      channel // reduction,
                      1,
                      padding=0,
                      bias_attr=True), nn.ReLU(),
            nn.Conv2D(channel // reduction,
                      channel,
                      1,
                      padding=0,
                      bias_attr=True), nn.Sigmoid())

    def forward(self, x):
        y = self.avg_pool(x)
        y = self.conv_du(y)
        return x * y


class RCAB(nn.Layer):
    def __init__(self,
                 conv,
                 n_feat,
                 kernel_size,
                 reduction=16,
                 bias=True,
                 bn=False,
                 act=nn.ReLU(),
                 res_scale=1):
        super(RCAB, self).__init__()
        modules_body = []
        for i in range(2):
            modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias))
            if bn: modules_body.append(nn.BatchNorm2D(n_feat))
            if i == 0: modules_body.append(act)
        modules_body.append(CALayer(n_feat, reduction))
        self.body = nn.Sequential(*modules_body)
        self.res_scale = res_scale

    def forward(self, x):
        res = self.body(x)
        res += x
        return res


## Residual Group (RG)
class ResidualGroup(nn.Layer):
    def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale,
                 n_resblocks):
        super(ResidualGroup, self).__init__()
        modules_body = [
            RCAB(conv,
                 n_feat,
                 kernel_size,
                 reduction,
                 bias=True,
                 bn=False,
                 act=nn.ReLU(),
                 res_scale=1) for _ in range(n_resblocks)
        ]
        modules_body.append(conv(n_feat, n_feat, kernel_size))
        self.body = nn.Sequential(*modules_body)

    def forward(self, x):
        res = self.body(x)
        res += x
        return res


class Upsampler(nn.Sequential):
    def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True):
        m = []
        if (scale & (scale - 1)) == 0:  # Is scale = 2^n?
            for _ in range(int(math.log(scale, 2))):
                m.append(conv(n_feats, 4 * n_feats, 3, bias))
                m.append(nn.PixelShuffle(2))
                if bn: m.append(nn.BatchNorm2D(n_feats))

                if act == 'relu':
                    m.append(nn.ReLU())
                elif act == 'prelu':
                    m.append(nn.PReLU(n_feats))

        elif scale == 3:
            m.append(conv(n_feats, 9 * n_feats, 3, bias))
            m.append(nn.PixelShuffle(3))
            if bn: m.append(nn.BatchNorm2D(n_feats))

            if act == 'relu':
                m.append(nn.ReLU())
            elif act == 'prelu':
                m.append(nn.PReLU(n_feats))
        else:
            raise NotImplementedError

        super(Upsampler, self).__init__(*m)


@GENERATORS.register()
class RCAN(nn.Layer):
    def __init__(
            self,
            scale,
            n_resgroups,
            n_resblocks,
            n_feats=64,
            n_colors=3,
            rgb_range=255,
            kernel_size=3,
            reduction=16,
            conv=default_conv,
    ):
        super(RCAN, self).__init__()
        self.scale = scale
        act = nn.ReLU()

        rgb_mean = (0.4488, 0.4371, 0.4040)
        rgb_std = (1.0, 1.0, 1.0)
        self.sub_mean = MeanShift(rgb_range, rgb_mean, rgb_std)

        # define head module
        modules_head = [conv(n_colors, n_feats, kernel_size)]

        # define body module
        modules_body = [
            ResidualGroup(conv,
                          n_feats,
                          kernel_size,
                          reduction,
                          act=act,
                          res_scale=1,
                          n_resblocks=n_resblocks)
            for _ in range(n_resgroups)
        ]
        modules_body.append(conv(n_feats, n_feats, kernel_size))

        # define tail module
        modules_tail = [
            Upsampler(conv, scale, n_feats, act=False),
            conv(n_feats, n_colors, kernel_size)
        ]

        self.head = nn.Sequential(*modules_head)
        self.body = nn.Sequential(*modules_body)
        self.tail = nn.Sequential(*modules_tail)

        self.add_mean = MeanShift(rgb_range, rgb_mean, rgb_std, 1)

    def forward(self, x):
        x = self.sub_mean(x)
        x = self.head(x)

        res = self.body(x)
        res += x

        x = self.tail(res)
        x = self.add_mean(x)

        return x
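
The `Upsampler` above reaches x4 by chaining two stages, each a convolution that expands the features to `4 * n_feats` channels followed by `nn.PixelShuffle(2)`. The rearrangement done by one pixel-shuffle stage is easy to miss, so here is a framework-free sketch on nested lists. `pixel_shuffle` is a hypothetical helper written for illustration; it follows the usual convention that input channel `c * r * r + i * r + j` fills output position `(h * r + i, w * r + j)` of output channel `c`.

```python
def pixel_shuffle(x, r):
    """Rearrange C*r*r channels of size H x W into C channels of size (H*r) x (W*r).

    x: list of channels, each an H x W nested list.
    r: integer upscale factor.
    """
    c_in, h, w = len(x), len(x[0]), len(x[0][0])
    if c_in % (r * r) != 0:
        raise ValueError("channel count must be divisible by r * r")
    c_out = c_in // (r * r)
    out = [[[0] * (w * r) for _ in range(h * r)] for _ in range(c_out)]
    for c in range(c_out):
        for i in range(r):
            for j in range(r):
                # Each group of r*r input channels supplies one sub-pixel offset.
                src = x[c * r * r + i * r + j]
                for row in range(h):
                    for col in range(w):
                        out[c][row * r + i][col * r + j] = src[row][col]
    return out
```

Four 1x1 input channels with `r=2` become a single 2x2 channel, so spatial resolution is gained entirely by moving values out of the channel dimension.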
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
from .generators.builder import build_generator
from .criterions.builder import build_criterion
from .base_model import BaseModel
from .builder import MODELS
from ..utils.visual import tensor2img
from ..modules.init import reset_parameters


@MODELS.register()
class RCANModel(BaseModel):
    """RCAN model for single image super-resolution.
    """
    def __init__(self, generator, pixel_criterion=None, use_init_weight=False):
        """
        Args:
            generator (dict): config of generator.
            pixel_criterion (dict): config of pixel criterion.
            use_init_weight (bool): whether to re-initialize the generator
                weights with init_sr_weight.
        """
        super(RCANModel, self).__init__()

        self.nets['generator'] = build_generator(generator)
        # reference loss for skipping abnormal batches, see train_iter
        self.error_last = 1e8
        self.batch = 0
        if pixel_criterion:
            self.pixel_criterion = build_criterion(pixel_criterion)
        if use_init_weight:
            init_sr_weight(self.nets['generator'])

    def setup_input(self, input):
        self.lq = paddle.to_tensor(input['lq'])
        self.visual_items['lq'] = self.lq
        if 'gt' in input:
            self.gt = paddle.to_tensor(input['gt'])
            self.visual_items['gt'] = self.gt
        self.image_paths = input['lq_path']

    def forward(self):
        pass

    def train_iter(self, optims=None):
        optims['optim'].clear_grad()

        self.output = self.nets['generator'](self.lq)
        self.visual_items['output'] = self.output
        # pixel loss
        loss_pixel = self.pixel_criterion(self.output, self.gt)
        self.losses['loss_pixel'] = loss_pixel

        # only update the weights when the loss is not abnormally large
        skip_threshold = 1e6

        if loss_pixel.item() < skip_threshold * self.error_last:
            loss_pixel.backward()
            optims['optim'].step()
        else:
            print('Skip this batch {}! (Loss: {})'.format(
                self.batch + 1, loss_pixel.item()))
        self.batch += 1

        # refresh the reference loss every 1000 batches
        if self.batch % 1000 == 0:
            self.error_last = loss_pixel.item() / 1000
            print("update error_last:{}".format(self.error_last))

    def test_iter(self, metrics=None):
        self.nets['generator'].eval()
        with paddle.no_grad():
            self.output = self.nets['generator'](self.lq)
            self.visual_items['output'] = self.output
        self.nets['generator'].train()

        out_img = []
        gt_img = []
        for out_tensor, gt_tensor in zip(self.output, self.gt):
            out_img.append(tensor2img(out_tensor, (0., 255.)))
            gt_img.append(tensor2img(gt_tensor, (0., 255.)))

        if metrics is not None:
            for metric in metrics.values():
                metric.update(out_img, gt_img)


def init_sr_weight(net):
    def reset_func(m):
        if hasattr(m, 'weight') and (not isinstance(
                m, (nn.BatchNorm, nn.BatchNorm2D))):
            reset_parameters(m)

    net.apply(reset_func)
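
The batch-skipping rule inside `train_iter` can be isolated from Paddle to make its behavior easier to follow. The sketch below mirrors the logic as written in this commit: a step is taken only while the loss stays below `skip_threshold * error_last`, and `error_last` is refreshed from the current loss at a fixed batch interval. The class name `LossSkipGuard` and the `update_every` parameter are hypothetical; the source hard-codes 1000 for both the interval and the divisor.

```python
class LossSkipGuard:
    """Decide whether an optimizer step should be taken for a given loss value."""

    def __init__(self, skip_threshold=1e6, update_every=1000):
        self.skip_threshold = skip_threshold
        self.update_every = update_every
        self.error_last = 1e8  # same initial reference loss as RCANModel
        self.batch = 0

    def should_step(self, loss):
        # Step only if the loss is not abnormally large relative to the reference.
        ok = loss < self.skip_threshold * self.error_last
        self.batch += 1
        # Periodically refresh the reference from the current loss.
        if self.batch % self.update_every == 0:
            self.error_last = loss / self.update_every
        return ok
```

With the default threshold of `1e6` the guard only trips on losses that explode by many orders of magnitude, so in a healthy run every batch is expected to pass.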