Unverified Commit 349bcb10 authored by kongdebug, committed by GitHub

[Paper Reproduction Challenge] SwinIR (#692)

* add SwinIR model

* [fix] add swinir

* [fix] fix url bug

* [fix] fix something for pr

* [feature] add README.md for SwinIR

* [fix] fix the testsets path

* fix testsets path again

* [fix] fix test data path
Parent b2f18921
......@@ -119,6 +119,7 @@ GAN-Generative Adversarial Network, was praised by "the Father of Convolutional
* [MPR Net](./docs/en_US/tutorials/mpr_net.md)
* [FaceEnhancement](./docs/en_US/tutorials/face_enhancement.md)
* [PReNet](./docs/en_US/tutorials/prenet.md)
* [SwinIR](./docs/en_US/tutorials/swinir.md)
## Composite Application
......
......@@ -138,7 +138,7 @@ GAN--Generative Adversarial Network, praised by "the Father of Convolutional Networks" **Yann LeCun (杨立昆)
* Video super-resolution: [Video Super Resolution (VSR)](./docs/zh_CN/tutorials/video_super_resolution.md)
* Included models: ⭐ PP-MSVSR ⭐, EDVR, BasicVSR, BasicVSR++
* Image and video restoration
* Image deblurring, denoising and deraining: [MPR Net](./docs/zh_CN/tutorials/mpr_net.md)
* Image deblurring, denoising and deraining: [MPR Net](./docs/zh_CN/tutorials/mpr_net.md), [SwinIR](./docs/zh_CN/tutorials/swinir.md)
* Video deblurring: [EDVR](./docs/zh_CN/tutorials/video_super_resolution.md)
* Image deraining: [PReNet](./docs/zh_CN/tutorials/prenet.md)
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys

sys.path.insert(0, os.getcwd())

import argparse

import paddle

from ppgan.apps import SwinIRPredictor
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--output_path",
type=str,
default='output_dir',
help="path to output image dir")
parser.add_argument("--weight_path",
type=str,
default=None,
help="path to model checkpoint path")
parser.add_argument("--seed",
type=int,
default=None,
help="sample random seed for model's image generation")
parser.add_argument('--images_path',
default=None,
required=True,
type=str,
help='Single image or images directory.')
parser.add_argument("--cpu",
dest="cpu",
action="store_true",
help="cpu mode.")
args = parser.parse_args()
if args.cpu:
paddle.set_device('cpu')
predictor = SwinIRPredictor(output_path=args.output_path,
weight_path=args.weight_path,
seed=args.seed)
predictor.run(images_path=args.images_path)
total_iters: 6400000
output_dir: output_dir
model:
name: SwinIRModel
generator:
name: SwinIR
upscale: 1
img_size: 128
window_size: 8
depths: [6, 6, 6, 6, 6, 6]
embed_dim: 180
num_heads: [6, 6, 6, 6, 6, 6]
mlp_ratio: 2
char_criterion:
name: CharbonnierLoss
eps: 0.000000001
reduction: mean
dataset:
train:
name: SwinIRDataset
num_workers: 8
batch_size: 2 # 1GPU
opt:
phase: train
n_channels: 3
H_size: 128
sigma: 15
sigma_test: 15
dataroot_H: data/trainsets/trainH
test:
name: SwinIRDataset
num_workers: 1
batch_size: 1
opt:
phase: test
n_channels: 3
H_size: 128
sigma: 15
sigma_test: 15
dataroot_H: data/trainsets/CBSD68
export_model:
- {name: 'generator', inputs_num: 1}
lr_scheduler:
name: MultiStepDecay
learning_rate: 5e-5 # num_gpu * 5e-5
milestones: [3200000, 4800000, 5600000, 6000000, 6400000]
gamma: 0.5
validate:
interval: 200
save_img: True
metrics:
psnr: # metric name, can be arbitrary
name: PSNR
crop_border: 4
test_y_channel: True
ssim:
name: SSIM
crop_border: 4
test_y_channel: True
optimizer:
name: Adam
# add parameters of net_name to optim
  # name should be in self.nets
net_names:
- generator
beta1: 0.9
beta2: 0.999
epsilon: 1e-8
log_config:
interval: 10
visiual_interval: 5000
snapshot_config:
interval: 500
English | [Chinese](../../zh_CN/tutorials/swinir.md)
## SwinIR: A Strong Baseline Model for Image Restoration Based on Swin Transformer
## 1 Introduction
The structure of SwinIR is relatively simple; if you are familiar with the Swin Transformer, there is nothing difficult about it. The authors apply the Swin Transformer to low-level vision tasks, including image super-resolution reconstruction, image denoising, and compression artifact removal. The network consists of a shallow feature extraction module, a deep feature extraction module, and a reconstruction module; the reconstruction module uses a different structure for each task. Shallow feature extraction is a single 3×3 convolutional layer. Deep feature extraction is composed of k RSTB blocks followed by a convolutional layer, with a residual connection. Each RSTB (Residual Swin Transformer Block) consists of L STLs (Swin Transformer Layers) and a convolutional layer, again with a residual connection. The structure of the model is shown in the following figure:
![](https://ai-studio-static-online.cdn.bcebos.com/b550e84915634951af756a545c643c815001be73372248b0b5179fd1652ae003)
For a more detailed introduction to the model, please refer to the original paper [SwinIR: Image Restoration Using Swin Transformer](https://arxiv.org/pdf/2108.10257.pdf). PaddleGAN currently provides weights for the denoising task.
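As a quick way to see these pieces in code, the snippet below is a minimal sketch (not part of the official docs) that builds the denoising generator with the settings used in `configs/swinir_denoising.yaml` and runs it on a random patch, assuming PaddleGAN is installed:
```python
# Minimal sketch: build the denoising generator with the settings from
# configs/swinir_denoising.yaml and run a dummy patch (shapes are illustrative).
import paddle
from ppgan.models.generators import SwinIR

net = SwinIR(upscale=1, img_size=128, window_size=8,
             depths=[6, 6, 6, 6, 6, 6], embed_dim=180,
             num_heads=[6, 6, 6, 6, 6, 6], mlp_ratio=2)
net.eval()
with paddle.no_grad():
    out = net(paddle.rand([1, 3, 128, 128]))  # noisy RGB patch in [0, 1]
print(out.shape)  # [1, 3, 128, 128] -- upscale=1, so the output keeps the input size
```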
## 2 How to use
### 2.1 Quick start
After installing PaddleGAN, you can run the following command to generate the restored image.
```sh
python applications/tools/swinir_denoising.py --images_path ${PATH_OF_IMAGE}
```
Where `PATH_OF_IMAGE` is the path of the image you need to denoise, or the path of the folder containing the images.
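The command above is a thin wrapper over the `SwinIRPredictor` added in this PR; as a rough sketch, the equivalent Python call looks like this (paths are placeholders):
```python
# Equivalent Python usage of the predictor behind the script above.
# With weight_path=None, the released SwinIR_Denoising weights are downloaded automatically.
from ppgan.apps import SwinIRPredictor

predictor = SwinIRPredictor(output_path='output_dir', weight_path=None, seed=None)
predictor.run(images_path='path/to/noisy_images')  # a single image or a directory
```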
### 2.2 Prepare dataset
#### Train Dataset
[DIV2K](https://cv.snu.ac.kr/research/EDSR/DIV2K.tar) (800 training images) + [Flickr2K](https://cv.snu.ac.kr/research/EDSR/Flickr2K.tar) (2650 images) + [BSD500](http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/BSR/BSR_bsds500.tgz) (400 training&testing images) + [WED](http://ivc.uwaterloo.ca/database/WaterlooExploration/exploration_database_and_code.rar)(4744 images)
A prepared copy of the dataset is available on [AI Studio](https://aistudio.baidu.com/aistudio/datasetdetail/149405).
Place the training data under `data/trainsets/trainH`.
#### Test Dataset
The test set is CBSD68, available on [AI Studio](https://aistudio.baidu.com/aistudio/datasetdetail/147756).
Extract it to `data/trainsets/CBSD68`.
### 2.3 Training
The example below trains the denoising task. To train other tasks, change the dataset and modify the config file accordingly.
```sh
python -u tools/main.py --config-file configs/swinir_denoising.yaml
```
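The config comments `batch_size: 2 # 1GPU` and `learning_rate: 5e-5 # num_gpu * 5e-5` give per-GPU values. For multi-GPU training, a hedged sketch using Paddle's standard distributed launcher (adjust the learning rate in the config to match your GPU count):
```sh
# assumption: 4 GPUs; set lr_scheduler.learning_rate in the config to num_gpu * 5e-5
python -m paddle.distributed.launch --gpus "0,1,2,3" tools/main.py --config-file configs/swinir_denoising.yaml
```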
### 2.4 Test
Test the model:
```sh
python tools/main.py --config-file configs/swinir_denoising.yaml --evaluate-only --load ${PATH_OF_WEIGHT}
```
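The config also declares an `export_model` entry, and the TIPC parameters added later in this PR exercise static-graph export and standalone inference. A hedged sketch assembled from those parameters (exact flags and output paths may need adjusting):
```sh
# Export the generator to a static-graph inference model (flags taken from the TIPC params in this PR)
python tools/export_model.py -c configs/swinir_denoising.yaml --inputs_size=1,3,128,128 --model_name inference --load ${PATH_OF_WEIGHT} --output_dir ./inference
# Standalone tiled inference with the exported model; --model_path should point at the exported files
python tools/inference.py --model_type swinir -c configs/swinir_denoising.yaml --seed 100 --device gpu --output_path output_dir/ --model_path ./inference/swinir/swinirmodel_generator
```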
## 3 Results
Denoising
| model | dataset | PSNR/SSIM |
|---|---|---|
| SwinIR | CBSD68 | 36.0819 / 0.9464 |
## 4 Download
| model | link |
|---|---|
| SwinIR| [SwinIR_Denoising](https://paddlegan.bj.bcebos.com/models/SwinIR_Denoising.pdparams) |
# References
- [SwinIR: Image Restoration Using Swin Transformer](https://arxiv.org/pdf/2108.10257.pdf)
```
@article{liang2021swinir,
title={SwinIR: Image Restoration Using Swin Transformer},
author={Liang, Jingyun and Cao, Jiezhang and Sun, Guolei and Zhang, Kai and Van Gool, Luc and Timofte, Radu},
journal={arXiv preprint arXiv:2108.10257},
year={2021}
}
```
[English](../../en_US/tutorials/swinir.md) | Chinese
## SwinIR: A Strong Baseline Model for Image Restoration Based on Swin Transformer
## 1 Introduction
The structure of SwinIR is relatively simple; if you are familiar with the Swin Transformer, there is nothing difficult about it. The authors apply the Swin Transformer to low-level vision tasks, including image super-resolution reconstruction, image denoising, and compression artifact removal. The network consists of a shallow feature extraction module, a deep feature extraction module, and a reconstruction module; the reconstruction module uses a different structure for each task. Shallow feature extraction is a single 3×3 convolutional layer. Deep feature extraction is composed of k RSTB blocks followed by a convolutional layer, with a residual connection. Each RSTB (Residual Swin Transformer Block) consists of L STLs (Swin Transformer Layers) and a convolutional layer, again with a residual connection. The structure of the model is shown in the following figure:
![](https://ai-studio-static-online.cdn.bcebos.com/b550e84915634951af756a545c643c815001be73372248b0b5179fd1652ae003)
For a more detailed introduction to the model, please refer to the original paper [SwinIR: Image Restoration Using Swin Transformer](https://arxiv.org/pdf/2108.10257.pdf). PaddleGAN currently provides weights for the denoising task.
## 2 How to use
### 2.1 Quick start
After installing `PaddleGAN`, enter the `PaddleGAN` directory and run the following command to generate the restored image at `./output_dir/Denoising/image_name.png`:
```sh
python applications/tools/swinir_denoising.py --images_path ${PATH_OF_IMAGE}
```
Where `PATH_OF_IMAGE` is the path of the image you need to denoise, or the path of the folder containing the images.
### 2.2 Prepare dataset
#### Train Dataset
[DIV2K](https://cv.snu.ac.kr/research/EDSR/DIV2K.tar) (800 training images) + [Flickr2K](https://cv.snu.ac.kr/research/EDSR/Flickr2K.tar) (2650 images) + [BSD500](http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/BSR/BSR_bsds500.tgz) (400 training&testing images) + [WED](http://ivc.uwaterloo.ca/database/WaterlooExploration/exploration_database_and_code.rar) (4744 images)
A prepared copy of the dataset is available on [AI Studio](https://aistudio.baidu.com/aistudio/datasetdetail/149405).
Place the training data under `data/trainsets/trainH`.
#### Test Dataset
The test set is CBSD68, available on [AI Studio](https://aistudio.baidu.com/aistudio/datasetdetail/147756).
Extract it to `data/trainsets/CBSD68`.
- After preparation, the layout under `PaddleGAN/data` is as follows:
```sh
trainsets
├── trainH
│   ├── 101085.png
│   ├── 101086.png
│   ├── ......
│   └── 201085.png
└── CBSD68
    ├── 271035.png
    ├── ......
    └── 351093.png
```
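To confirm the layout is picked up correctly, the following is a small sanity check using the `SwinIRDataset` added in this PR; the `opt` keys mirror the dataset section of `configs/swinir_denoising.yaml`:
```python
# Sanity-check the dataset layout with the SwinIRDataset introduced in this PR.
from ppgan.datasets import SwinIRDataset

opt = {'phase': 'train', 'n_channels': 3, 'H_size': 128,
       'sigma': 15, 'sigma_test': 15, 'dataroot_H': 'data/trainsets/trainH'}
dataset = SwinIRDataset(opt)
img_H, img_L, filename = dataset[0]  # clean patch, noisy patch, file stem
print(len(dataset), img_H.shape, img_L.shape, filename)
```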
### 2.3 Training
The example below trains the denoising task. To train other tasks, change the dataset and modify the config file accordingly.
```sh
python -u tools/main.py --config-file configs/swinir_denoising.yaml
```
### 2.4 Test
Test the model:
```sh
python tools/main.py --config-file configs/swinir_denoising.yaml --evaluate-only --load ${PATH_OF_WEIGHT}
```
## 3 Results
Denoising
| Model | Dataset | PSNR/SSIM |
|---|---|---|
| SwinIR | CBSD68 | 36.0819 / 0.9464 |
## 4 Download
| Model | Link |
|---|---|
| SwinIR| [SwinIR_Denoising](https://paddlegan.bj.bcebos.com/models/SwinIR_Denoising.pdparams) |
# References
- [SwinIR: Image Restoration Using Swin Transformer](https://arxiv.org/pdf/2108.10257.pdf)
```
@article{liang2021swinir,
title={SwinIR: Image Restoration Using Swin Transformer},
author={Liang, Jingyun and Cao, Jiezhang and Sun, Guolei and Zhang, Kai and Van Gool, Luc and Timofte, Radu},
journal={arXiv preprint arXiv:2108.10257},
year={2021}
}
```
......@@ -37,3 +37,4 @@ from .recurrent_vsr_predictor import (PPMSVSRPredictor, BasicVSRPredictor, \
PPMSVSRLargePredictor)
from .singan_predictor import SinGANPredictor
from .gpen_predictor import GPENPredictor
from .swinir_predictor import SwinIRPredictor
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import cv2
from glob import glob
from natsort import natsorted
import numpy as np
import os
import random
from tqdm import tqdm
import paddle
from ppgan.models.generators import SwinIR
from ppgan.utils.download import get_path_from_url
from .base_predictor import BasePredictor
model_cfgs = {
'Denoising': {
'model_urls':
'https://paddlegan.bj.bcebos.com/models/SwinIR_Denoising.pdparams',
'upscale': 1,
'img_size': 128,
'window_size': 8,
'depths': [6, 6, 6, 6, 6, 6],
'embed_dim': 180,
'num_heads': [6, 6, 6, 6, 6, 6],
'mlp_ratio': 2
}
}
class SwinIRPredictor(BasePredictor):
def __init__(self,
output_path='output_dir',
weight_path=None,
seed=None,
window_size=8):
self.output_path = output_path
task = 'Denoising'
self.task = task
self.window_size = window_size
if weight_path is None:
if task in model_cfgs.keys():
weight_path = get_path_from_url(model_cfgs[task]['model_urls'])
checkpoint = paddle.load(weight_path)
else:
raise ValueError('Predictor need a task to define!')
else:
            if weight_path.startswith("http"):  # os.path.islink doesn't work!
weight_path = get_path_from_url(weight_path)
checkpoint = paddle.load(weight_path)
else:
checkpoint = paddle.load(weight_path)
self.generator = SwinIR(upscale=model_cfgs[task]['upscale'],
img_size=model_cfgs[task]['img_size'],
window_size=model_cfgs[task]['window_size'],
depths=model_cfgs[task]['depths'],
embed_dim=model_cfgs[task]['embed_dim'],
num_heads=model_cfgs[task]['num_heads'],
mlp_ratio=model_cfgs[task]['mlp_ratio'])
checkpoint = checkpoint['generator']
self.generator.set_state_dict(checkpoint)
self.generator.eval()
if seed is not None:
paddle.seed(seed)
random.seed(seed)
np.random.seed(seed)
def get_images(self, images_path):
if os.path.isdir(images_path):
return natsorted(
glob(os.path.join(images_path, '*.jpeg')) +
glob(os.path.join(images_path, '*.jpg')) +
glob(os.path.join(images_path, '*.JPG')) +
glob(os.path.join(images_path, '*.png')) +
glob(os.path.join(images_path, '*.PNG')))
else:
return [images_path]
def imread_uint(self, path, n_channels=3):
# input: path
# output: HxWx3(RGB or GGG), or HxWx1 (G)
if n_channels == 1:
img = cv2.imread(path, 0) # cv2.IMREAD_GRAYSCALE
img = np.expand_dims(img, axis=2) # HxWx1
elif n_channels == 3:
img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # BGR or G
if img.ndim == 2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) # GGG
else:
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # RGB
return img
def uint2single(self, img):
return np.float32(img / 255.)
# convert single (HxWxC) to 3-dimensional paddle tensor
def single2tensor3(self, img):
return paddle.Tensor(np.ascontiguousarray(
img, dtype=np.float32)).transpose([2, 0, 1])
def run(self, images_path=None):
os.makedirs(self.output_path, exist_ok=True)
task_path = os.path.join(self.output_path, self.task)
os.makedirs(task_path, exist_ok=True)
image_files = self.get_images(images_path)
for image_file in tqdm(image_files):
img_L = self.imread_uint(image_file, 3)
image_name = os.path.basename(image_file)
img = cv2.cvtColor(img_L, cv2.COLOR_RGB2BGR)
cv2.imwrite(os.path.join(task_path, image_name), img)
tmps = image_name.split('.')
            assert len(
                tmps) == 2, f'Invalid image name: {image_name}, too many "."'
restoration_save_path = os.path.join(
task_path, f'{tmps[0]}_restoration.{tmps[1]}')
img_L = self.uint2single(img_L)
# HWC to CHW, numpy to tensor
img_L = self.single2tensor3(img_L)
img_L = img_L.unsqueeze(0)
with paddle.no_grad():
# pad input image to be a multiple of window_size
_, _, h_old, w_old = img_L.shape
h_pad = (h_old // self.window_size +
1) * self.window_size - h_old
w_pad = (w_old // self.window_size +
1) * self.window_size - w_old
img_L = paddle.concat([img_L, paddle.flip(img_L, [2])],
2)[:, :, :h_old + h_pad, :]
img_L = paddle.concat([img_L, paddle.flip(img_L, [3])],
3)[:, :, :, :w_old + w_pad]
output = self.generator(img_L)
output = output[..., :h_old, :w_old]
restored = paddle.clip(output, 0, 1)
restored = restored.numpy()
restored = restored.transpose(0, 2, 3, 1)
restored = restored[0]
restored = restored * 255
restored = restored.astype(np.uint8)
cv2.imwrite(restoration_save_path,
cv2.cvtColor(restored, cv2.COLOR_RGB2BGR))
print('Done, output path is:', task_path)
......@@ -31,3 +31,4 @@ from .vsr_folder_dataset import VSRFolderDataset
from .photopen_dataset import PhotoPenDataset
from .empty_dataset import EmptyDataset
from .gpen_dataset import GPENDataset
from .swinir_dataset import SwinIRDataset
# code was heavily based on https://github.com/cszn/KAIR
# MIT License
# Copyright (c) 2019 Kai Zhang
import os
import random
import numpy as np
import cv2
import paddle
from paddle.io import Dataset
from .builder import DATASETS
def is_image_file(filename):
return any(
filename.endswith(extension)
for extension in ['jpeg', 'JPEG', 'jpg', 'png', 'JPG', 'PNG', 'gif'])
def get_image_paths(dataroot):
paths = None # return None if dataroot is None
if isinstance(dataroot, str):
paths = sorted(_get_paths_from_images(dataroot))
elif isinstance(dataroot, list):
paths = []
for i in dataroot:
paths += sorted(_get_paths_from_images(i))
return paths
def _get_paths_from_images(path):
assert os.path.isdir(path), '{:s} is not a valid directory'.format(path)
images = []
for dirpath, _, fnames in sorted(os.walk(path)):
for fname in sorted(fnames):
if is_image_file(fname):
img_path = os.path.join(dirpath, fname)
images.append(img_path)
assert images, '{:s} has no valid image file'.format(path)
return images
def imread_uint(path, n_channels=3):
# input: path
# output: HxWx3(RGB or GGG), or HxWx1 (G)
if n_channels == 1:
img = cv2.imread(path, 0) # cv2.IMREAD_GRAYSCALE
img = np.expand_dims(img, axis=2) # HxWx1
elif n_channels == 3:
img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # BGR or G
if img.ndim == 2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) # GGG
else:
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # RGB
return img
def augment_img(img, mode=0):
if mode == 0:
return img
elif mode == 1:
return np.flipud(np.rot90(img))
elif mode == 2:
return np.flipud(img)
elif mode == 3:
return np.rot90(img, k=3)
elif mode == 4:
return np.flipud(np.rot90(img, k=2))
elif mode == 5:
return np.rot90(img)
elif mode == 6:
return np.rot90(img, k=2)
elif mode == 7:
return np.flipud(np.rot90(img, k=3))
def uint2tensor3(img):
if img.ndim == 2:
img = np.expand_dims(img, axis=2)
return paddle.Tensor(np.ascontiguousarray(img, dtype=np.float32)).transpose(
[2, 0, 1]) / 255.
def uint2single(img):
return np.float32(img / 255.)
# convert single (HxWxC) to 3-dimensional paddle tensor
def single2tensor3(img):
return paddle.Tensor(np.ascontiguousarray(img, dtype=np.float32)).transpose(
[2, 0, 1])
@DATASETS.register()
class SwinIRDataset(Dataset):
""" Get L/H for denosing on AWGN with fixed sigma.
Ref:
DnCNN: Beyond a Gaussian Denoiser: Residual Learning of Deep CNN for Image Denoising
Args:
opt (dict): A dictionary defining dataset-related parameters.
"""
def __init__(self, opt=None):
super(SwinIRDataset, self).__init__()
print(
            'Dataset: Denoising on AWGN with fixed sigma. Only dataroot_H is needed.'
)
self.opt = opt
self.n_channels = opt['n_channels'] if opt['n_channels'] else 3
self.patch_size = opt['H_size'] if opt['H_size'] else 64
self.sigma = opt['sigma'] if opt['sigma'] else 25
self.sigma_test = opt['sigma_test'] if opt['sigma_test'] else self.sigma
self.paths_H = get_image_paths(opt['dataroot_H'])
def __len__(self):
return len(self.paths_H)
def __getitem__(self, index):
# get H image
H_path = self.paths_H[index]
img_H = imread_uint(H_path, self.n_channels)
L_path = H_path
if self.opt['phase'] == 'train':
# get L/H patch pairs
H, W, _ = img_H.shape
# randomly crop the patch
rnd_h = random.randint(0, max(0, H - self.patch_size))
rnd_w = random.randint(0, max(0, W - self.patch_size))
patch_H = img_H[rnd_h:rnd_h + self.patch_size,
rnd_w:rnd_w + self.patch_size, :]
# augmentation - flip, rotate
mode = random.randint(0, 7)
patch_H = augment_img(patch_H, mode=mode)
img_H = uint2tensor3(patch_H)
img_L = img_H.clone()
# add noise
noise = paddle.randn(img_L.shape) * self.sigma / 255.0
img_L = img_L + noise
else:
# get L/H image pairs
img_H = uint2single(img_H)
img_L = np.copy(img_H)
# add noise
np.random.seed(seed=0)
img_L += np.random.normal(0, self.sigma_test / 255.0, img_L.shape)
# HWC to CHW, numpy to tensor
img_L = single2tensor3(img_L)
img_H = single2tensor3(img_H)
filename = os.path.splitext(os.path.split(H_path)[-1])[0]
return img_H, img_L, filename
......@@ -38,3 +38,4 @@ from .singan_model import SinGANModel
from .rcan_model import RCANModel
from .prenet_model import PReNetModel
from .gpen_model import GPENModel
from .swinir_model import SwinIRModel
......@@ -42,3 +42,4 @@ from .generator_singan import SinGANGenerator
from .rcan import RCAN
from .prenet import PReNet
from .gpen import GPEN
from .swinir import SwinIR
# code was heavily based on https://github.com/cszn/KAIR
# MIT License
# Copyright (c) 2019 Kai Zhang
"""
DropPath, reimplemented from https://github.com/yueatsprograms/Stochastic_Depth
"""
from itertools import repeat
import collections.abc
import math
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from .builder import GENERATORS
def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable):
return x
return tuple(repeat(x, n))
return parse
class DropPath(nn.Layer):
"""DropPath class"""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def drop_path(self, inputs):
"""drop path op
Args:
input: tensor with arbitrary shape
drop_prob: float number of drop path probability, default: 0.0
training: bool, if current mode is training, default: False
Returns:
output: output tensor after drop path
"""
# if prob is 0 or eval mode, return original input
if self.drop_prob == 0. or not self.training:
return inputs
keep_prob = 1 - self.drop_prob
keep_prob = paddle.to_tensor(keep_prob, dtype='float32')
shape = (
inputs.shape[0], ) + (1, ) * (inputs.ndim - 1) # shape=(N, 1, 1, 1)
random_tensor = keep_prob + paddle.rand(shape, dtype=inputs.dtype)
random_tensor = random_tensor.floor() # mask
output = inputs.divide(
keep_prob
) * random_tensor # divide is to keep same output expectation
return output
def forward(self, inputs):
return self.drop_path(inputs)
to_2tuple = _ntuple(2)
@paddle.jit.not_to_static
def swapdim(x, dim1, dim2):
a = list(range(len(x.shape)))
a[dim1], a[dim2] = a[dim2], a[dim1]
return x.transpose(a)
class Identity(nn.Layer):
""" Identity layer
The output of this layer is the input without any change.
Use this layer to avoid if condition in some forward methods
"""
def __init__(self):
super(Identity, self).__init__()
def forward(self, x):
return x
class Mlp(nn.Layer):
def __init__(self, in_features, hidden_features, dropout):
super(Mlp, self).__init__()
w_attr_1, b_attr_1 = self._init_weights()
self.fc1 = nn.Linear(in_features,
hidden_features,
weight_attr=w_attr_1,
bias_attr=b_attr_1)
w_attr_2, b_attr_2 = self._init_weights()
self.fc2 = nn.Linear(hidden_features,
in_features,
weight_attr=w_attr_2,
bias_attr=b_attr_2)
self.act = nn.GELU()
self.dropout = nn.Dropout(dropout)
def _init_weights(self):
weight_attr = paddle.ParamAttr(
initializer=paddle.nn.initializer.XavierUniform())
bias_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.Normal(
std=1e-6))
return weight_attr, bias_attr
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.dropout(x)
x = self.fc2(x)
x = self.dropout(x)
return x
class WindowAttention(nn.Layer):
"""Window based multihead attention, with relative position bias.
Both shifted window and non-shifted window are supported.
Args:
dim (int): input dimension (channels)
window_size (int): height and width of the window
num_heads (int): number of attention heads
qkv_bias (bool): if True, enable learnable bias to q,k,v, default: True
qk_scale (float): override default qk scale head_dim**-0.5 if set, default: None
attention_dropout (float): dropout of attention
dropout (float): dropout for output
"""
def __init__(self,
dim,
window_size,
num_heads,
qkv_bias=True,
qk_scale=None,
attention_dropout=0.,
dropout=0.):
super(WindowAttention, self).__init__()
self.window_size = window_size
self.num_heads = num_heads
self.dim = dim
self.dim_head = dim // num_heads
self.scale = qk_scale or self.dim_head**-0.5
self.relative_position_bias_table = paddle.create_parameter(
shape=[(2 * window_size[0] - 1) * (2 * window_size[1] - 1),
num_heads],
dtype='float32',
default_initializer=paddle.nn.initializer.TruncatedNormal(std=.02))
weight_attr, bias_attr = self._init_weights()
# relative position index for each token inside window
coords_h = paddle.arange(self.window_size[0])
coords_w = paddle.arange(self.window_size[1])
coords = paddle.stack(paddle.meshgrid([coords_h, coords_w
])) # [2, window_h, window_w]
coords_flatten = paddle.flatten(coords, 1) # [2, window_h * window_w]
        # 2, window_h * window_w, window_h * window_w
relative_coords = coords_flatten.unsqueeze(
2) - coords_flatten.unsqueeze(1)
        # window_h*window_w, window_h*window_w, 2
relative_coords = relative_coords.transpose([1, 2, 0])
relative_coords[:, :, 0] += self.window_size[0] - 1
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
# [window_size * window_size, window_size*window_size]
relative_position_index = relative_coords.sum(-1)
self.register_buffer("relative_position_index", relative_position_index)
self.qkv = nn.Linear(dim,
dim * 3,
weight_attr=weight_attr,
bias_attr=bias_attr if qkv_bias else False)
self.attn_dropout = nn.Dropout(attention_dropout)
self.proj = nn.Linear(dim,
dim,
weight_attr=weight_attr,
bias_attr=bias_attr)
self.proj_dropout = nn.Dropout(dropout)
self.softmax = nn.Softmax(axis=-1)
def transpose_multihead(self, x):
tensor_shape = list(x.shape[:-1])
new_shape = tensor_shape + [self.num_heads, self.dim_head]
x = x.reshape(new_shape)
x = x.transpose([0, 2, 1, 3])
return x
def get_relative_pos_bias_from_pos_index(self):
# relative_position_bias_table is a ParamBase object
# https://github.com/PaddlePaddle/Paddle/blob/067f558c59b34dd6d8626aad73e9943cf7f5960f/python/paddle/fluid/framework.py#L5727
table = self.relative_position_bias_table # N x num_heads
# index is a tensor
index = self.relative_position_index.reshape(
[-1]) # window_h*window_w * window_h*window_w
# NOTE: paddle does NOT support indexing Tensor by a Tensor
relative_position_bias = paddle.index_select(x=table, index=index)
return relative_position_bias
def forward(self, x, mask=None):
qkv = self.qkv(x).chunk(3, axis=-1)
q, k, v = map(self.transpose_multihead, qkv)
q = q * self.scale
attn = paddle.matmul(q, k, transpose_y=True)
relative_position_bias = self.get_relative_pos_bias_from_pos_index()
relative_position_bias = relative_position_bias.reshape([
self.window_size[0] * self.window_size[1],
self.window_size[0] * self.window_size[1], -1
])
# nH, window_h*window_w, window_h*window_w
relative_position_bias = relative_position_bias.transpose([2, 0, 1])
attn = attn + relative_position_bias.unsqueeze(0)
if mask is not None:
nW = mask.shape[0]
attn = attn.reshape(
[x.shape[0] // nW, nW, self.num_heads, x.shape[1], x.shape[1]])
attn += mask.unsqueeze(1).unsqueeze(0)
attn = attn.reshape([-1, self.num_heads, x.shape[1], x.shape[1]])
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
attn = self.attn_dropout(attn)
z = paddle.matmul(attn, v)
z = z.transpose([0, 2, 1, 3])
tensor_shape = list(z.shape[:-2])
new_shape = tensor_shape + [self.dim]
z = z.reshape(new_shape)
z = self.proj(z)
z = self.proj_dropout(z)
return z
def _init_weights(self):
weight_attr = paddle.ParamAttr(
initializer=nn.initializer.TruncatedNormal(std=.02))
bias_attr = paddle.ParamAttr(initializer=nn.initializer.Constant(0))
return weight_attr, bias_attr
def extra_repr(self) -> str:
return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
def flops(self, N):
# calculate flops for 1 window with token length of N
flops = 0
flops += N * self.dim * 3 * self.dim
flops += self.num_heads * N * (self.dim // self.num_heads) * N
flops += self.num_heads * N * N * (self.dim // self.num_heads)
flops += N * self.dim * self.dim
return flops
def windows_partition(x, window_size):
"""
Args:
x: (B, H, W, C)
window_size (int): window size
Returns:
windows: (num_windows*B, window_size, window_size, C)
"""
B, H, W, C = x.shape
x = x.reshape(
[B, H // window_size, window_size, W // window_size, window_size, C])
windows = x.transpose([0, 1, 3, 2, 4,
5]).reshape([-1, window_size, window_size, C])
return windows
def windows_reverse(windows, window_size, H, W):
"""
Args:
windows: (num_windows*B, window_size, window_size, C)
window_size (int): Window size
H (int): Height of image
W (int): Width of image
Returns:
x: (B, H, W, C)
"""
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.reshape(
[B, H // window_size, W // window_size, window_size, window_size, -1])
x = x.transpose([0, 1, 3, 2, 4, 5]).reshape([B, H, W, -1])
return x
class SwinTransformerBlock(nn.Layer):
"""Swin transformer block
Contains window multi head self attention, droppath, mlp, norm and residual.
Attributes:
dim: int, input dimension (channels)
        input_resolution: int, input resolution
        num_heads: int, number of attention heads
        window_size: int, window size, default: 7
shift_size: int, shift size for SW-MSA, default: 0
mlp_ratio: float, ratio of mlp hidden dim and input embedding dim, default: 4.
qkv_bias: bool, if True, enable learnable bias to q,k,v, default: True
qk_scale: float, override default qk scale head_dim**-0.5 if set, default: None
dropout: float, dropout for output, default: 0.
attention_dropout: float, dropout of attention, default: 0.
droppath: float, drop path rate, default: 0.
"""
def __init__(self,
dim,
input_resolution,
num_heads,
window_size=7,
shift_size=0,
mlp_ratio=4.,
qkv_bias=True,
qk_scale=None,
dropout=0.,
attention_dropout=0.,
droppath=0.):
super(SwinTransformerBlock, self).__init__()
self.dim = dim
self.input_resolution = input_resolution
self.num_heads = num_heads
self.window_size = window_size
self.shift_size = shift_size
self.mlp_ratio = mlp_ratio
if min(self.input_resolution) <= self.window_size:
self.shift_size = 0
self.window_size = min(self.input_resolution)
self.norm1 = nn.LayerNorm(dim)
self.attn = WindowAttention(dim,
window_size=to_2tuple(self.window_size),
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attention_dropout=attention_dropout,
dropout=dropout)
self.drop_path = DropPath(droppath) if droppath > 0. else Identity()
self.norm2 = nn.LayerNorm(dim)
self.mlp = Mlp(in_features=dim,
hidden_features=int(dim * mlp_ratio),
dropout=dropout)
attn_mask = self.calculate_mask(self.input_resolution)
self.register_buffer("attn_mask", attn_mask)
def calculate_mask(self, x_size):
if self.shift_size > 0:
# calculate attention mask for SW-MSA
H, W = x_size
img_mask = paddle.zeros((1, H, W, 1))
h_slices = (slice(0, -self.window_size),
slice(-self.window_size,
-self.shift_size), slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
slice(-self.window_size,
-self.shift_size), slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
mask_windows = windows_partition(
img_mask, self.window_size) # nW, window_size, window_size, 1
mask_windows = mask_windows.reshape(
[-1, self.window_size * self.window_size])
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
huns = -100.0 * paddle.ones_like(attn_mask)
attn_mask = huns * (attn_mask != 0).astype("float32")
return attn_mask
else:
return None
def forward(self, x, x_size):
H, W = x_size
B, L, C = x.shape
shortcut = x
x = self.norm1(x)
x = x.reshape([B, H, W, C])
# cyclic shift
if self.shift_size > 0:
shifted_x = paddle.roll(x,
shifts=(-self.shift_size, -self.shift_size),
axis=(1, 2))
else:
shifted_x = x
# partition windows
x_windows = windows_partition(shifted_x, self.window_size)
x_windows = x_windows.reshape(
[-1, self.window_size * self.window_size, C])
        # W-MSA/SW-MSA (recompute the attention mask when the input size differs from the training resolution)
if self.input_resolution == x_size:
attn_windows = self.attn(x_windows, mask=self.attn_mask)
else:
attn_windows = self.attn(x_windows,
mask=self.calculate_mask(x_size))
# merge windows
attn_windows = attn_windows.reshape(
[-1, self.window_size, self.window_size, C])
shifted_x = windows_reverse(attn_windows, self.window_size, H, W)
# reverse cyclic shift
if self.shift_size > 0:
x = paddle.roll(shifted_x,
shifts=(self.shift_size, self.shift_size),
axis=(1, 2))
else:
x = shifted_x
x = x.reshape([B, H * W, C])
# FFN
x = shortcut + self.drop_path(x)
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
def extra_repr(self) -> str:
return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
def flops(self):
flops = 0
H, W = self.input_resolution
# norm1
flops += self.dim * H * W
# W-MSA/SW-MSA
nW = H * W / self.window_size / self.window_size
flops += nW * self.attn.flops(self.window_size * self.window_size)
# mlp
flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
# norm2
flops += self.dim * H * W
return flops
class PatchMerging(nn.Layer):
""" Patch Merging class
    Merge multiple patches into one patch and keep the output dim.
    Specifically, merge adjacent 2x2 patches (dim=C) into 1 patch.
    The concatenated dim 4*C is rescaled to 2*C.
Args:
input_resolution (tuple | ints): the size of input
dim: dimension of single patch
reduction: nn.Linear which maps 4C to 2C dim
norm: nn.LayerNorm, applied after linear layer.
"""
def __init__(self, input_resolution, dim):
super(PatchMerging, self).__init__()
self.input_resolution = input_resolution
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias_attr=False)
self.norm = nn.LayerNorm(4 * dim)
def forward(self, x):
h, w = self.input_resolution
b, _, c = x.shape
x = x.reshape([b, h, w, c])
x0 = x[:, 0::2, 0::2, :] # [B, H/2, W/2, C]
x1 = x[:, 1::2, 0::2, :] # [B, H/2, W/2, C]
x2 = x[:, 0::2, 1::2, :] # [B, H/2, W/2, C]
x3 = x[:, 1::2, 1::2, :] # [B, H/2, W/2, C]
x = paddle.concat([x0, x1, x2, x3], -1) #[B, H/2, W/2, 4*C]
x = x.reshape([b, -1, 4 * c]) # [B, H/2*W/2, 4*C]
x = self.norm(x)
x = self.reduction(x)
return x
def extra_repr(self) -> str:
return f"input_resolution={self.input_resolution}, dim={self.dim}"
def flops(self):
H, W = self.input_resolution
flops = H * W * self.dim
flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
return flops
class BasicLayer(nn.Layer):
""" A basic Swin Transformer layer for one stage.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resolution.
depth (int): Number of blocks.
num_heads (int): Number of attention heads.
window_size (int): Local window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
dropout (float, optional): Dropout rate. Default: 0.0
attention_dropout (float, optional): Attention dropout rate. Default: 0.0
droppath (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
downsample (nn.Layer | None, optional): Downsample layer at the end of the layer. Default: None
"""
def __init__(self,
dim,
input_resolution,
depth,
num_heads,
window_size,
mlp_ratio=4.,
qkv_bias=True,
qk_scale=None,
dropout=0.,
attention_dropout=0.,
droppath=0.,
downsample=None):
super(BasicLayer, self).__init__()
self.dim = dim
self.input_resolution = input_resolution
self.depth = depth
self.blocks = nn.LayerList()
for i in range(depth):
self.blocks.append(
SwinTransformerBlock(dim=dim,
input_resolution=input_resolution,
num_heads=num_heads,
window_size=window_size,
shift_size=0 if
(i % 2 == 0) else window_size // 2,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
dropout=dropout,
attention_dropout=attention_dropout,
droppath=droppath[i] if isinstance(
droppath, list) else droppath))
if downsample is not None:
self.downsample = downsample(input_resolution, dim=dim)
else:
self.downsample = None
def forward(self, x, x_size):
for block in self.blocks:
x = block(x, x_size)
if self.downsample is not None:
x = self.downsample(x)
return x
def extra_repr(self) -> str:
return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
def flops(self):
flops = 0
for blk in self.blocks:
flops += blk.flops()
if self.downsample is not None:
flops += self.downsample.flops()
return flops
class RSTB(nn.Layer):
"""Residual Swin Transformer Block (RSTB).
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resolution.
depth (int): Number of blocks.
num_heads (int): Number of attention heads.
window_size (int): Local window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
downsample (nn.Layer | None, optional): Downsample layer at the end of the layer. Default: None
img_size: Input image size.
patch_size: Patch size.
resi_connection: The convolutional block before residual connection.
"""
def __init__(self,
dim,
input_resolution,
depth,
num_heads,
window_size,
mlp_ratio=4.,
qkv_bias=True,
qk_scale=None,
drop=0.,
attn_drop=0.,
drop_path=0.,
downsample=None,
img_size=224,
patch_size=4,
resi_connection='1conv'):
super(RSTB, self).__init__()
self.dim = dim
self.input_resolution = input_resolution
self.residual_group = BasicLayer(dim=dim,
input_resolution=input_resolution,
depth=depth,
num_heads=num_heads,
window_size=window_size,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
dropout=drop,
attention_dropout=attn_drop,
droppath=drop_path,
downsample=downsample)
if resi_connection == '1conv':
self.conv = nn.Conv2D(dim, dim, 3, 1, 1)
elif resi_connection == '3conv':
# to save parameters and memory
self.conv = nn.Sequential(nn.Conv2D(dim, dim // 4, 3, 1, 1),
nn.LeakyReLU(negative_slope=0.2),
nn.Conv2D(dim // 4, dim // 4, 1, 1, 0),
nn.LeakyReLU(negative_slope=0.2),
nn.Conv2D(dim // 4, dim, 3, 1, 1))
self.patch_embed = PatchEmbed(img_size=img_size,
patch_size=patch_size,
in_chans=0,
embed_dim=dim,
norm_layer=None)
self.patch_unembed = PatchUnEmbed(img_size=img_size,
patch_size=patch_size,
in_chans=0,
embed_dim=dim,
norm_layer=None)
def forward(self, x, x_size):
return self.patch_embed(
self.conv(self.patch_unembed(self.residual_group(x, x_size),
x_size))) + x
def flops(self):
flops = 0
flops += self.residual_group.flops()
H, W = self.input_resolution
flops += H * W * self.dim * self.dim * 9
flops += self.patch_embed.flops()
flops += self.patch_unembed.flops()
return flops
class PatchEmbed(nn.Layer):
r""" Image to Patch Embedding
Args:
img_size (int): Image size. Default: 224.
patch_size (int): Patch token size. Default: 4.
in_chans (int): Number of input image channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
norm_layer (nn.Layer, optional): Normalization layer. Default: None
"""
def __init__(self,
img_size=224,
patch_size=4,
in_chans=3,
embed_dim=96,
norm_layer=None):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
patches_resolution = [
img_size[0] // patch_size[0], img_size[1] // patch_size[1]
]
self.img_size = img_size
self.patch_size = patch_size
self.patches_resolution = patches_resolution
self.num_patches = patches_resolution[0] * patches_resolution[1]
self.in_chans = in_chans
self.embed_dim = embed_dim
if norm_layer is not None:
self.norm = norm_layer(embed_dim)
else:
self.norm = None
def forward(self, x):
x = x.flatten(2).transpose([0, 2, 1]) # B Ph*Pw C
if self.norm is not None:
x = self.norm(x)
return x
def flops(self):
flops = 0
H, W = self.img_size
if self.norm is not None:
flops += H * W * self.embed_dim
return flops
class PatchUnEmbed(nn.Layer):
r""" Image to Patch Unembedding
Args:
img_size (int): Image size. Default: 224.
patch_size (int): Patch token size. Default: 4.
in_chans (int): Number of input image channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
norm_layer (nn.Layer, optional): Normalization layer. Default: None
"""
def __init__(self,
img_size=224,
patch_size=4,
in_chans=3,
embed_dim=96,
norm_layer=None):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
patches_resolution = [
img_size[0] // patch_size[0], img_size[1] // patch_size[1]
]
self.img_size = img_size
self.patch_size = patch_size
self.patches_resolution = patches_resolution
self.num_patches = patches_resolution[0] * patches_resolution[1]
self.in_chans = in_chans
self.embed_dim = embed_dim
def forward(self, x, x_size):
B, HW, C = x.shape
x = x.transpose([0, 2,
1]).reshape([B, self.embed_dim, x_size[0],
x_size[1]]) # B Ph*Pw C
return x
def flops(self):
flops = 0
return flops
class Upsample(nn.Sequential):
"""Upsample module.
Args:
scale (int): Scale factor. Supported scales: 2^n and 3.
num_feat (int): Channel number of intermediate features.
"""
def __init__(self, scale, num_feat):
m = []
if (scale & (scale - 1)) == 0: # scale = 2^n
for _ in range(int(math.log(scale, 2))):
m.append(nn.Conv2D(num_feat, 4 * num_feat, 3, 1, 1))
m.append(nn.PixelShuffle(2))
elif scale == 3:
m.append(nn.Conv2D(num_feat, 9 * num_feat, 3, 1, 1))
m.append(nn.PixelShuffle(3))
else:
raise ValueError(f'scale {scale} is not supported. '
'Supported scales: 2^n and 3.')
super(Upsample, self).__init__(*m)
class UpsampleOneStep(nn.Sequential):
"""UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
Used in lightweight SR to save parameters.
Args:
scale (int): Scale factor. Supported scales: 2^n and 3.
num_feat (int): Channel number of intermediate features.
"""
def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
self.num_feat = num_feat
self.input_resolution = input_resolution
m = []
m.append(nn.Conv2D(num_feat, (scale**2) * num_out_ch, 3, 1, 1))
m.append(nn.PixelShuffle(scale))
super(UpsampleOneStep, self).__init__(*m)
def flops(self):
H, W = self.input_resolution
flops = H * W * self.num_feat * 3 * 9
return flops
@GENERATORS.register()
class SwinIR(nn.Layer):
r""" SwinIR
    A PaddlePaddle implementation of `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer.
Args:
img_size (int | tuple(int)): Input image size. Default 64
patch_size (int | tuple(int)): Patch size. Default: 1
in_chans (int): Number of input image channels. Default: 3
embed_dim (int): Patch embedding dimension. Default: 96
depths (tuple(int)): Depth of each Swin Transformer layer.
num_heads (tuple(int)): Number of attention heads in different layers.
window_size (int): Window size. Default: 7
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
drop_rate (float): Dropout rate. Default: 0
attn_drop_rate (float): Attention dropout rate. Default: 0
drop_path_rate (float): Stochastic depth rate. Default: 0.1
norm_layer (nn.Layer): Normalization layer. Default: nn.LayerNorm.
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
patch_norm (bool): If True, add normalization after patch embedding. Default: True
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction
img_range: Image range. 1. or 255.
        upsampler: The reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
"""
def __init__(self,
img_size=64,
patch_size=1,
in_chans=3,
embed_dim=96,
depths=[6, 6, 6, 6],
num_heads=[6, 6, 6, 6],
window_size=7,
mlp_ratio=4.,
qkv_bias=True,
qk_scale=None,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.1,
norm_layer=nn.LayerNorm,
ape=False,
patch_norm=True,
upscale=2,
img_range=1.,
upsampler='',
resi_connection='1conv'):
super(SwinIR, self).__init__()
num_in_ch = in_chans
num_out_ch = in_chans
num_feat = 64
self.img_range = img_range
if in_chans == 3:
rgb_mean = np.array([0.4488, 0.4371, 0.4040], dtype=np.float32)
self.mean = paddle.Tensor(rgb_mean).reshape([1, 3, 1, 1])
else:
            self.mean = paddle.zeros([1, 1, 1, 1], dtype=paddle.float32)
self.upscale = upscale
self.upsampler = upsampler
self.window_size = window_size
# 1. shallow feature extraction
self.conv_first = nn.Conv2D(num_in_ch, embed_dim, 3, 1, 1)
# 2. deep feature extraction
self.num_layers = len(depths)
self.embed_dim = embed_dim
self.ape = ape
self.patch_norm = patch_norm
self.num_features = embed_dim
self.mlp_ratio = mlp_ratio
# split image into non-overlapping patches
self.patch_embed = PatchEmbed(
img_size=img_size,
patch_size=patch_size,
in_chans=embed_dim,
embed_dim=embed_dim,
norm_layer=norm_layer if self.patch_norm else None)
num_patches = self.patch_embed.num_patches
patches_resolution = self.patch_embed.patches_resolution
self.patches_resolution = patches_resolution
# merge non-overlapping patches into image
self.patch_unembed = PatchUnEmbed(
img_size=img_size,
patch_size=patch_size,
in_chans=embed_dim,
embed_dim=embed_dim,
norm_layer=norm_layer if self.patch_norm else None)
# absolute position embedding
if self.ape:
self.absolute_pos_embed = paddle.nn.ParameterList([
paddle.create_parameter(
shape=[1, num_patches, embed_dim],
dtype='float32',
default_initializer=paddle.nn.initializer.TruncatedNormal(
std=.02))
])
self.pos_drop = nn.Dropout(p=drop_rate)
# stochastic depth
dpr = [
x.item() for x in paddle.linspace(0, drop_path_rate, sum(depths))
] # stochastic depth decay rule
# build Residual Swin Transformer blocks (RSTB)
self.layers = nn.LayerList()
for i_layer in range(self.num_layers):
layer = RSTB(
dim=embed_dim,
input_resolution=(patches_resolution[0], patches_resolution[1]),
depth=depths[i_layer],
num_heads=num_heads[i_layer],
window_size=window_size,
mlp_ratio=self.mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i_layer]
):sum(depths[:i_layer +
1])], # no impact on SR results
downsample=None,
img_size=img_size,
patch_size=patch_size,
resi_connection=resi_connection)
self.layers.append(layer)
self.norm = norm_layer(self.num_features)
# build the last conv layer in deep feature extraction
if resi_connection == '1conv':
self.conv_after_body = nn.Conv2D(embed_dim, embed_dim, 3, 1, 1)
elif resi_connection == '3conv':
# to save parameters and memory
self.conv_after_body = nn.Sequential(
nn.Conv2D(embed_dim, embed_dim // 4, 3, 1, 1),
nn.LeakyReLU(negative_slope=0.2),
nn.Conv2D(embed_dim // 4, embed_dim // 4, 1, 1, 0),
nn.LeakyReLU(negative_slope=0.2),
nn.Conv2D(embed_dim // 4, embed_dim, 3, 1, 1))
# 3, high quality image reconstruction ################################
if self.upsampler == 'pixelshuffle':
# for classical SR
self.conv_before_upsample = nn.Sequential(
nn.Conv2D(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU())
self.upsample = Upsample(upscale, num_feat)
self.conv_last = nn.Conv2D(num_feat, num_out_ch, 3, 1, 1)
elif self.upsampler == 'pixelshuffledirect':
# for lightweight SR (to save parameters)
self.upsample = UpsampleOneStep(
upscale, embed_dim, num_out_ch,
(patches_resolution[0], patches_resolution[1]))
elif self.upsampler == 'nearest+conv':
# for real-world SR (less artifacts)
assert self.upscale == 4, 'only support x4 now.'
self.conv_before_upsample = nn.Sequential(
nn.Conv2D(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU())
self.conv_up1 = nn.Conv2D(num_feat, num_feat, 3, 1, 1)
self.conv_up2 = nn.Conv2D(num_feat, num_feat, 3, 1, 1)
self.conv_hr = nn.Conv2D(num_feat, num_feat, 3, 1, 1)
self.conv_last = nn.Conv2D(num_feat, num_out_ch, 3, 1, 1)
self.lrelu = nn.LeakyReLU(negative_slope=0.2)
else:
# for image denoising and JPEG compression artifact reduction
self.conv_last = nn.Conv2D(embed_dim, num_out_ch, 3, 1, 1)
def no_weight_decay(self):
return {'absolute_pos_embed'}
def no_weight_decay_keywords(self):
return {'relative_position_bias_table'}
def check_image_size(self, x):
_, _, h, w = x.shape
mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
return x
def forward_features(self, x):
x_size = (x.shape[2], x.shape[3])
x = self.patch_embed(x)
if self.ape:
x = x + self.absolute_pos_embed
x = self.pos_drop(x)
for layer in self.layers:
x = layer(x, x_size)
x = self.norm(x) # B L C
x = self.patch_unembed(x, x_size)
return x
def forward(self, x):
H, W = x.shape[2:]
x = self.check_image_size(x)
x = (x - self.mean) * self.img_range
if self.upsampler == 'pixelshuffle':
# for classical SR
x = self.conv_first(x)
x = self.conv_after_body(self.forward_features(x)) + x
x = self.conv_before_upsample(x)
x = self.conv_last(self.upsample(x))
elif self.upsampler == 'pixelshuffledirect':
# for lightweight SR
x = self.conv_first(x)
x = self.conv_after_body(self.forward_features(x)) + x
x = self.upsample(x)
elif self.upsampler == 'nearest+conv':
# for real-world SR
x = self.conv_first(x)
x = self.conv_after_body(self.forward_features(x)) + x
x = self.conv_before_upsample(x)
x = self.lrelu(
self.conv_up1(
paddle.nn.functional.interpolate(x,
scale_factor=2,
mode='nearest')))
x = self.lrelu(
self.conv_up2(
paddle.nn.functional.interpolate(x,
scale_factor=2,
mode='nearest')))
x = self.conv_last(self.lrelu(self.conv_hr(x)))
else:
# for image denoising and JPEG compression artifact reduction
x_first = self.conv_first(x)
res = self.conv_after_body(self.forward_features(x_first)) + x_first
x = x + self.conv_last(res)
x = x / self.img_range + self.mean
return x[:, :, :H * self.upscale, :W * self.upscale]
def flops(self):
flops = 0
H, W = self.patches_resolution
flops += H * W * 3 * self.embed_dim * 9
flops += self.patch_embed.flops()
for i, layer in enumerate(self.layers):
flops += layer.flops()
flops += H * W * 3 * self.embed_dim * self.embed_dim
flops += self.upsample.flops()
return flops
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import paddle
import paddle.nn as nn
from .builder import MODELS
from .base_model import BaseModel
from .generators.builder import build_generator
from .criterions.builder import build_criterion
from ppgan.utils.visual import tensor2img
@MODELS.register()
class SwinIRModel(BaseModel):
"""SwinIR Model.
SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257
Originally Written by Ze Liu, Modified by Jingyun Liang.
"""
def __init__(self, generator, char_criterion=None):
"""Initialize the MPR class.
Args:
generator (dict): config of generator.
char_criterion (dict): config of char criterion.
"""
super(SwinIRModel, self).__init__(generator)
self.current_iter = 1
self.nets['generator'] = build_generator(generator)
if char_criterion:
self.char_criterion = build_criterion(char_criterion)
def setup_input(self, input):
self.target = input[0]
self.lq = input[1]
def train_iter(self, optims=None):
optims['optim'].clear_gradients()
restored = self.nets['generator'](self.lq)
loss = self.char_criterion(restored, self.target)
loss.backward()
optims['optim'].step()
self.losses['loss'] = loss.numpy()
def forward(self):
pass
def test_iter(self, metrics=None):
self.nets['generator'].eval()
with paddle.no_grad():
self.output = self.nets['generator'](self.lq)
self.visual_items['output'] = self.output
self.nets['generator'].train()
out_img = []
gt_img = []
for out_tensor, gt_tensor in zip(self.output, self.target):
out_img.append(tensor2img(out_tensor, (0., 1.)))
gt_img.append(tensor2img(gt_tensor, (0., 1.)))
if metrics is not None:
for metric in metrics.values():
metric.update(out_img, gt_img)
def export_model(self,
export_model=None,
output_dir=None,
inputs_size=None,
export_serving_model=False,
model_name=None):
shape = inputs_size[0]
new_model = self.nets['generator']
new_model.eval()
input_spec = [paddle.static.InputSpec(shape=shape, dtype="float32")]
static_model = paddle.jit.to_static(new_model, input_spec=input_spec)
if output_dir is None:
output_dir = 'inference_model'
if model_name is None:
model_name = '{}_{}'.format(self.__class__.__name__.lower(),
export_model[0]['name'])
paddle.jit.save(static_model, os.path.join(output_dir, model_name))
===========================train_params===========================
model_name:swinir
python:python3.7
gpu_list:0
##
auto_cast:null
total_iters:lite_train_lite_infer=10
output_dir:./output/
snapshot_config.interval:lite_train_lite_infer=10
pretrained_model:null
train_model_name:swinir*/*checkpoint.pdparams
train_infer_img_dir:null
null:null
##
trainer:norm_train
norm_train:tools/main.py -c configs/swinir_denoising.yaml --seed 100 -o log_config.interval=1
pact_train:null
fpgm_train:null
distill_train:null
null:null
null:null
##
===========================eval_params===========================
eval:null
null:null
##
===========================infer_params===========================
--output_dir:./output/
load:null
norm_export:tools/export_model.py -c configs/swinir_denoising.yaml --inputs_size=1,3,128,128 --model_name inference --load
quant_export:null
fpgm_export:null
distill_export:null
export1:null
export2:null
inference_dir:inference
train_model:./inference/swinir/swinirmodel_generator
infer_export:null
infer_quant:False
inference:tools/inference.py --model_type swinir --seed 100 -c configs/swinir_denoising.yaml --output_path test_tipc/output/
--device:gpu
null:null
null:null
null:null
null:null
null:null
--model_path:
null:null
null:null
--benchmark:True
null:null
......@@ -62,6 +62,10 @@ if [ ${MODE} = "lite_train_lite_infer" ];then
rm -rf ./data/DIV2K*
wget -nc -P ./data/ https://paddlegan.bj.bcebos.com/datasets/DIV2KandSet14paddle.tar --no-check-certificate
cd ./data/ && tar xf DIV2KandSet14paddle.tar && cd ../ ;;
swinir)
rm -rf ./data/*sets
wget -nc -P ./data/ https://paddlegan.bj.bcebos.com/datasets/swinir_data.zip --no-check-certificate
cd ./data/ && unzip -q swinir_data.zip && cd ../ ;;
singan)
rm -rf ./data/singan*
wget -nc -P ./data/ https://paddlegan.bj.bcebos.com/datasets/singan-official_images.zip --no-check-certificate
......
......@@ -19,7 +19,7 @@ from ppgan.metrics import build_metric
MODEL_CLASSES = ["pix2pix", "cyclegan", "wav2lip", "esrgan", \
"edvr", "fom", "stylegan2", "basicvsr", "msvsr", "singan"]
"edvr", "fom", "stylegan2", "basicvsr", "msvsr", "singan", "swinir"]
def parse_args():
......@@ -334,6 +334,63 @@ def main():
metric_file = os.path.join(args.output_path, "singan/metric.txt")
for metric in metrics.values():
metric.update(prediction, data['A'])
elif model_type == "swinir":
lq = data[1].numpy()
_, _, h_old, w_old = lq.shape
window_size = 8
tile = 128
tile_overlap = 32
            # pad the input so that its height and width become multiples of window_size
h_pad = (h_old // window_size + 1) * window_size - h_old
w_pad = (w_old // window_size + 1) * window_size - w_old
lq = np.concatenate([lq, np.flip(lq, 2)],
axis=2)[:, :, :h_old + h_pad, :]
lq = np.concatenate([lq, np.flip(lq, 3)],
axis=3)[:, :, :, :w_old + w_pad]
lq = lq.astype("float32")
b, c, h, w = lq.shape
tile = min(tile, h, w)
assert tile % window_size == 0, "tile size should be a multiple of window_size"
sf = 1 # scale
stride = tile - tile_overlap
h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
E = np.zeros([b, c, h * sf, w * sf], dtype=np.float32)
W = np.zeros_like(E)
for h_idx in h_idx_list:
for w_idx in w_idx_list:
in_patch = lq[..., h_idx:h_idx + tile, w_idx:w_idx + tile]
input_handles[0].copy_from_cpu(in_patch)
predictor.run()
out_patch = output_handle.copy_to_cpu()
out_patch_mask = np.ones_like(out_patch)
E[..., h_idx * sf:(h_idx + tile) * sf,
w_idx * sf:(w_idx + tile) * sf] += out_patch
W[..., h_idx * sf:(h_idx + tile) * sf,
w_idx * sf:(w_idx + tile) * sf] += out_patch_mask
output = np.true_divide(E, W)
prediction = output[..., :h_old * sf, :w_old * sf]
prediction = paddle.to_tensor(prediction)
target = tensor2img(data[0], (0., 1.))
prediction = tensor2img(prediction, (0., 1.))
metric_file = os.path.join(args.output_path, model_type,
"metric.txt")
for metric in metrics.values():
metric.update(prediction, target)
lq = tensor2img(data[1], (0., 1.))
sample_result = np.concatenate((lq, prediction, target), 1)
sample = cv2.cvtColor(sample_result, cv2.COLOR_RGB2BGR)
file_name = os.path.join(args.output_path, model_type,
"{}.png".format(i))
cv2.imwrite(file_name, sample)
if metrics:
log_file = open(metric_file, 'a')
......