diff --git a/configs/rcan_rssr_x4.yaml b/configs/rcan_rssr_x4.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9be1561832c674d6c163223ea918f49817e94f20 --- /dev/null +++ b/configs/rcan_rssr_x4.yaml @@ -0,0 +1,104 @@ +total_iters: 1000000 +output_dir: output_dir +# tensor range for function tensor2img +min_max: + (0., 255.) + +model: + name: RCANModel + generator: + name: RCAN + scale: 4 + n_resgroups: 10 + n_resblocks: 20 + pixel_criterion: + name: L1Loss + +dataset: + train: + name: SRDataset + gt_folder: data/DIV2K/DIV2K_train_HR_sub + lq_folder: data/DIV2K/DIV2K_train_LR_bicubic/X4_sub + num_workers: 4 + batch_size: 16 + scale: 4 + preprocess: + - name: LoadImageFromFile + key: lq + - name: LoadImageFromFile + key: gt + - name: Transforms + input_keys: [lq, gt] + pipeline: + - name: SRPairedRandomCrop + gt_patch_size: 192 + scale: 4 + keys: [image, image] + - name: PairedRandomHorizontalFlip + keys: [image, image] + - name: PairedRandomVerticalFlip + keys: [image, image] + - name: PairedRandomTransposeHW + keys: [image, image] + - name: Transpose + keys: [image, image] + - name: Normalize + mean: [0., .0, 0.] + std: [1., 1., 1.] + keys: [image, image] + test: + name: SRDataset + gt_folder: data/Set14/GTmod12 + lq_folder: data/Set14/LRbicx4 + scale: 4 + preprocess: + - name: LoadImageFromFile + key: lq + - name: LoadImageFromFile + key: gt + - name: Transforms + input_keys: [lq, gt] + pipeline: + - name: Transpose + keys: [image, image] + - name: Normalize + mean: [0., .0, 0.] + std: [1., 1., 1.] 
+ keys: [image, image] + +lr_scheduler: + name: CosineAnnealingRestartLR + learning_rate: 0.0001 + periods: [1000000] + restart_weights: [1] + eta_min: !!float 1e-7 + +optimizer: + name: Adam + # add parameters of net_name to optim + # name should in self.nets + net_names: + - generator + beta1: 0.9 + beta2: 0.99 + +validate: + interval: 2500 + save_img: false + + metrics: + psnr: # metric name, can be arbitrary + name: PSNR + crop_border: 4 + test_y_channel: True + ssim: + name: SSIM + crop_border: 4 + test_y_channel: True + +log_config: + interval: 10 + visiual_interval: 5000 + +snapshot_config: + interval: 2500 diff --git a/docs/imgs/RSSR.png b/docs/imgs/RSSR.png new file mode 100644 index 0000000000000000000000000000000000000000..dc4147740a3c19c32145c3def401f688395aec59 Binary files /dev/null and b/docs/imgs/RSSR.png differ diff --git a/docs/zh_CN/tutorials/remote_sensing_image_super-resolution.md b/docs/zh_CN/tutorials/remote_sensing_image_super-resolution.md new file mode 100644 index 0000000000000000000000000000000000000000..3fa4b537754b9e6db4ffc0eb35605014bc95fbb6 --- /dev/null +++ b/docs/zh_CN/tutorials/remote_sensing_image_super-resolution.md @@ -0,0 +1,70 @@ +# 1.单幅遥感图像超分辨率重建 + +## 1.1 背景和原理介绍 + + **意义与应用场景**:单幅影像超分辨率重建一直是low-level视觉领域中一个比较热门的任务,其可以成为修复老电影、老照片的技术手段,也可以为图像分割、目标检测等下游任务提供质量较高的数据。在遥感中的应用场景也比较广泛,例如:在**船舶检测和分类**等诸多遥感影像应用中,**提高遥感影像分辨率具有重要意义**。 + +**原理**:单幅遥感影像的超分辨率重建本质上与单幅影像超分辨率重建类似,均是使用RGB三通道的低分辨率影像生成纹理清晰的高分辨率影像。本项目复现的论文是[Yulun Zhang](http://yulunzhang.com/), [Kunpeng Li](https://kunpengli1994.github.io/), [Kai Li](http://kailigo.github.io/), [Lichen Wang](https://sites.google.com/site/lichenwang123/), [Bineng Zhong](https://scholar.google.de/citations?user=hvRBydsAAAAJ&hl=en), and [Yun Fu](http://www1.ece.neu.edu/~yunfu/), 发表在ECCV 2018上的论文[《Image Super-Resolution Using Very Deep Residual Channel Attention Networks》](https://arxiv.org/abs/1807.02758)。 
+作者提出了一个深度残差通道注意力网络(RCAN),引入一种通道注意力机制(CA),通过考虑通道之间的相互依赖性来自适应地重新调整特征。该模型取得优异的性能,因此本项目选择RCAN进行单幅遥感影像的x4超分辨率重建。 + +## 1.2 如何使用 + +### 1.2.1 数据准备 + 本项目的训练分为两个阶段,第一个阶段使用[DIV2K数据集](https://data.vision.ee.ethz.ch/cvl/DIV2K/)进行预训练RCANx4模型,然后基于该模型再使用[遥感超分数据集合](https://aistudio.baidu.com/aistudio/datasetdetail/129011)进行迁移学习。 + - 关于DIV2K数据的准备方法参考[该文档](./single_image_super_resolution.md) + - 遥感超分数据准备 + - 数据已经上传至AI Studio中,该数据为从UC Merced Land-Use Dataset 21 级土地利用图像遥感数据集中抽取部分遥感影像,通过BI退化生成的HR-LR影像对用于训练超分模型,其中训练集6720对,测试集420对 + - 下载解压后的文件组织形式如下 + ``` + ├── RSdata_for_SR + ├── train_HR + ├── train_LR + | └──x4 + ├── test_HR + ├── test_LR + | └──x4 + ``` + +### 1.2.2 DIV2K数据集上训练/测试 + +首先是在DIV2K数据集上训练RCANx4模型,并以Set14作为测试集。按照论文需要准备RCANx2作为初始化权重,可通过下表进行获取。 + +| 模型 | 数据集 | 下载地址 | +|---|---|---| +| RCANx2 | DIV2K | [RCANx2](https://paddlegan.bj.bcebos.com/models/RCAN_X2_DIV2K.pdparams) | + + +将DIV2K数据按照 [该文档](./single_image_super_resolution.md)所示准备好后,执行以下命令训练模型,`--load`的参数为下载好的RCANx2模型权重所在路径。 + +```shell +python -u tools/main.py --config-file configs/rcan_rssr_x4.yaml --load ${PATH_OF_WEIGHT} +``` + +训练好后,执行以下命令可对测试集Set14预测,`--load`的参数为训练好的RCANx4模型权重。 +```shell +python tools/main.py --config-file configs/rcan_rssr_x4.yaml --evaluate-only --load ${PATH_OF_WEIGHT} +``` + +本项目在DIV2K数据集训练迭代第57250次得到的权重[RCAN_X4_DIV2K](https://pan.baidu.com/s/1rI7yUdD4T1DE0RZB5yHXjA)(提取码:aglw),在Set14数据集上测得的精度:`PSNR:28.8959 SSIM:0.7896` + +### 1.2.3 遥感超分数据上迁移学习训练/测试 +- 使用该数据集,需要修改`rcan_rssr_x4.yaml`文件中训练集与测试集的高分辨率图像路径和低分辨率图像路径,即文件中的`gt_folder`和`lq_folder`。 +- 同时,由于使用了在DIV2K数据集上训练的RCAN_X4_DIV2K模型权重来进行迁移学习,所以训练的迭代次数`total_iters`也可以进行修改,并不需要很多次数的迭代就能有良好的效果。训练模型中`--load`的参数为下载好的RCANx4模型权重所在路径。 + +训练模型: +```shell +python -u tools/main.py --config-file configs/rcan_rssr_x4.yaml --load ${PATH_OF_RCANx4_WEIGHT} +``` +测试模型: +```shell +python -u tools/main.py --config-file configs/rcan_rssr_x4.yaml --evaluate-only --load ${PATH_OF_RCANx4_WEIGHT} +``` + +## 1.3 实验结果 + +- RCANx4遥感影像超分效果 + + + +- [RCAN遥感影像超分辨率重建 AI Studio 
项目在线体验](https://aistudio.baidu.com/aistudio/projectdetail/3508912) + diff --git a/ppgan/models/__init__.py b/ppgan/models/__init__.py index e9917769d254dea16a2e84d30da6f5a73d13b490..6f38794a8f1ed17d1979575a2ee960881fb1cdff 100644 --- a/ppgan/models/__init__.py +++ b/ppgan/models/__init__.py @@ -35,4 +35,5 @@ from .mpr_model import MPRModel from .photopen_model import PhotoPenModel from .msvsr_model import MultiStageVSRModel from .singan_model import SinGANModel +from .rcan_model import RCANModel from .prenet_model import PReNetModel diff --git a/ppgan/models/generators/__init__.py b/ppgan/models/generators/__init__.py index 8cf04b9b93f263f02ab03851928bccc9639137f5..56572a79806951a58a6a4520cd740f57c2471000 100755 --- a/ppgan/models/generators/__init__.py +++ b/ppgan/models/generators/__init__.py @@ -39,4 +39,5 @@ from .generater_photopen import SPADEGenerator from .basicvsr_plus_plus import BasicVSRPlusPlus from .msvsr import MSVSR from .generator_singan import SinGANGenerator +from .rcan import RCAN from .prenet import PReNet diff --git a/ppgan/models/generators/rcan.py b/ppgan/models/generators/rcan.py new file mode 100644 index 0000000000000000000000000000000000000000..8fe989b32e27821301f19ff2f00544ae44056f6d --- /dev/null +++ b/ppgan/models/generators/rcan.py @@ -0,0 +1,202 @@ +# base on https://github.com/kongdebug/RCAN-Paddle +import math +import paddle +import paddle.nn as nn + +from .builder import GENERATORS + + +def default_conv(in_channels, out_channels, kernel_size, bias=True): + weight_attr = paddle.ParamAttr( + initializer=paddle.nn.initializer.XavierUniform(), need_clip=True) + return nn.Conv2D(in_channels, + out_channels, + kernel_size, + padding=(kernel_size // 2), + weight_attr=weight_attr, + bias_attr=bias) + + +class MeanShift(nn.Conv2D): + + def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1): + super(MeanShift, self).__init__(3, 3, kernel_size=1) + std = paddle.to_tensor(rgb_std) + self.weight.set_value(paddle.eye(3).reshape([3, 3, 1, 
1])) + self.weight.set_value(self.weight / (std.reshape([3, 1, 1, 1]))) + + mean = paddle.to_tensor(rgb_mean) + self.bias.set_value(sign * rgb_range * mean / std) + + self.weight.trainable = False + self.bias.trainable = False + + +## Channel Attention (CA) Layer +class CALayer(nn.Layer): + + def __init__(self, channel, reduction=16): + super(CALayer, self).__init__() + # global average pooling: feature --> point + self.avg_pool = nn.AdaptiveAvgPool2D(1) + # feature channel downscale and upscale --> channel weight + self.conv_du = nn.Sequential( + nn.Conv2D(channel, + channel // reduction, + 1, + padding=0, + bias_attr=True), nn.ReLU(), + nn.Conv2D(channel // reduction, + channel, + 1, + padding=0, + bias_attr=True), nn.Sigmoid()) + + def forward(self, x): + y = self.avg_pool(x) + y = self.conv_du(y) + return x * y + + +class RCAB(nn.Layer): + + def __init__(self, + conv, + n_feat, + kernel_size, + reduction=16, + bias=True, + bn=False, + act=nn.ReLU(), + res_scale=1): + super(RCAB, self).__init__() + modules_body = [] + for i in range(2): + modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias)) + if bn: modules_body.append(nn.BatchNorm2D(n_feat)) + if i == 0: modules_body.append(act) + modules_body.append(CALayer(n_feat, reduction)) + self.body = nn.Sequential(*modules_body) + self.res_scale = res_scale + + def forward(self, x): + res = self.body(x) + res += x + return res + + +## Residual Group (RG) +class ResidualGroup(nn.Layer): + + def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale, + n_resblocks): + super(ResidualGroup, self).__init__() + modules_body = [] + modules_body = [ + RCAB( + conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(), res_scale=1) \ + for _ in range(n_resblocks)] + modules_body.append(conv(n_feat, n_feat, kernel_size)) + self.body = nn.Sequential(*modules_body) + + def forward(self, x): + res = self.body(x) + res += x + return res + + +class Upsampler(nn.Sequential): + + def 
__init__(self, conv, scale, n_feats, bn=False, act=False, bias=True): + m = [] + if (scale & (scale - 1)) == 0: # Is scale = 2^n? + for _ in range(int(math.log(scale, 2))): + m.append(conv(n_feats, 4 * n_feats, 3, bias)) + m.append(nn.PixelShuffle(2)) + if bn: m.append(nn.BatchNorm2D(n_feats)) + + if act == 'relu': + m.append(nn.ReLU()) + elif act == 'prelu': + m.append(nn.PReLU(n_feats)) + + elif scale == 3: + m.append(conv(n_feats, 9 * n_feats, 3, bias)) + m.append(nn.PixelShuffle(3)) + if bn: m.append(nn.BatchNorm2D(n_feats)) + + if act == 'relu': + m.append(nn.ReLU()) + elif act == 'prelu': + m.append(nn.PReLU(n_feats)) + else: + raise NotImplementedError + + super(Upsampler, self).__init__(*m) + + +@GENERATORS.register() +class RCAN(nn.Layer): + + def __init__( + self, + scale, + n_resgroups, + n_resblocks, + n_feats=64, + n_colors=3, + rgb_range=255, + kernel_size=3, + reduction=16, + conv=default_conv, + ): + super(RCAN, self).__init__() + self.scale = scale + act = nn.ReLU() + + n_resgroups = n_resgroups + n_resblocks = n_resblocks + n_feats = n_feats + kernel_size = kernel_size + reduction = reduction + scale = scale + act = nn.ReLU() + + rgb_mean = (0.4488, 0.4371, 0.4040) + rgb_std = (1.0, 1.0, 1.0) + self.sub_mean = MeanShift(rgb_range, rgb_mean, rgb_std) + + # define head module + modules_head = [conv(n_colors, n_feats, kernel_size)] + + # define body module + modules_body = [ + ResidualGroup( + conv, n_feats, kernel_size, reduction, act=act, res_scale= 1, n_resblocks=n_resblocks) \ + for _ in range(n_resgroups)] + + modules_body.append(conv(n_feats, n_feats, kernel_size)) + + # define tail module + modules_tail = [ + Upsampler(conv, scale, n_feats, act=False), + conv(n_feats, n_colors, kernel_size) + ] + + self.head = nn.Sequential(*modules_head) + self.body = nn.Sequential(*modules_body) + self.tail = nn.Sequential(*modules_tail) + + self.add_mean = MeanShift(rgb_range, rgb_mean, rgb_std, 1) + + def forward(self, x): + x = self.sub_mean(x) + x = 
self.head(x) + + res = self.body(x) + res += x + + x = self.tail(res) + x = self.add_mean(x) + + return x diff --git a/ppgan/models/rcan_model.py b/ppgan/models/rcan_model.py new file mode 100644 index 0000000000000000000000000000000000000000..45061f4a7d913c815166af6abd9b9ec96d74b527 --- /dev/null +++ b/ppgan/models/rcan_model.py @@ -0,0 +1,106 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.nn as nn + +from .generators.builder import build_generator +from .criterions.builder import build_criterion +from .base_model import BaseModel +from .builder import MODELS +from ..utils.visual import tensor2img +from ..modules.init import reset_parameters + + +@MODELS.register() +class RCANModel(BaseModel): + """Base SR model for single image super-resolution. + """ + + def __init__(self, generator, pixel_criterion=None, use_init_weight=False): + """ + Args: + generator (dict): config of generator. + pixel_criterion (dict): config of pixel criterion. 
+ """ + super(RCANModel, self).__init__() + + self.nets['generator'] = build_generator(generator) + self.error_last = 1e8 + self.batch = 0 + if pixel_criterion: + self.pixel_criterion = build_criterion(pixel_criterion) + if use_init_weight: + init_sr_weight(self.nets['generator']) + + def setup_input(self, input): + self.lq = paddle.to_tensor(input['lq']) + self.visual_items['lq'] = self.lq + if 'gt' in input: + self.gt = paddle.to_tensor(input['gt']) + self.visual_items['gt'] = self.gt + self.image_paths = input['lq_path'] + + def forward(self): + pass + + def train_iter(self, optims=None): + optims['optim'].clear_grad() + + self.output = self.nets['generator'](self.lq) + self.visual_items['output'] = self.output + # pixel loss + loss_pixel = self.pixel_criterion(self.output, self.gt) + self.losses['loss_pixel'] = loss_pixel + + skip_threshold = 1e6 + + if loss_pixel.item() < skip_threshold * self.error_last: + loss_pixel.backward() + optims['optim'].step() + else: + print('Skip this batch {}! (Loss: {})'.format( + self.batch + 1, loss_pixel.item())) + self.batch += 1 + + if self.batch % 1000 == 0: + self.error_last = loss_pixel.item() / 1000 + print("update error_last:{}".format(self.error_last)) + + def test_iter(self, metrics=None): + self.nets['generator'].eval() + with paddle.no_grad(): + self.output = self.nets['generator'](self.lq) + self.visual_items['output'] = self.output + self.nets['generator'].train() + + out_img = [] + gt_img = [] + for out_tensor, gt_tensor in zip(self.output, self.gt): + out_img.append(tensor2img(out_tensor, (0., 255.))) + gt_img.append(tensor2img(gt_tensor, (0., 255.))) + + if metrics is not None: + for metric in metrics.values(): + metric.update(out_img, gt_img) + + +def init_sr_weight(net): + + def reset_func(m): + if hasattr(m, 'weight') and (not isinstance( + m, (nn.BatchNorm, nn.BatchNorm2D))): + reset_parameters(m) + + net.apply(reset_func)