From 349bcb10edf9af21fc85aedae592f5c196a99762 Mon Sep 17 00:00:00 2001 From: kongdebug <52785738+kongdebug@users.noreply.github.com> Date: Wed, 28 Sep 2022 21:31:32 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90=E8=AE=BA=E6=96=87=E5=A4=8D=E7=8E=B0?= =?UTF-8?q?=E8=B5=9B=E3=80=91SwinIR=20(#692)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add SwinIR model * [fix] add swinir * [fix] fix url bug * [fix] fix something for pr * [feature] add README.md for SwinIR * [fix] fix the testsets path * fix testsets path again * [fix] fix test data path --- README.md | 1 + README_cn.md | 2 +- applications/tools/swinir_denoising.py | 60 + configs/swinir_denoising.yaml | 82 ++ docs/en_US/tutorials/swinir.md | 81 ++ docs/zh_CN/tutorials/swinir.md | 101 ++ ppgan/apps/__init__.py | 1 + ppgan/apps/swinir_predictor.py | 169 +++ ppgan/datasets/__init__.py | 1 + ppgan/datasets/swinir_dataset.py | 165 +++ ppgan/models/__init__.py | 1 + ppgan/models/generators/__init__.py | 1 + ppgan/models/generators/swinir.py | 1060 +++++++++++++++++ ppgan/models/swinir_model.py | 102 ++ .../configs/swinir/train_infer_python.txt | 51 + test_tipc/prepare.sh | 4 + tools/inference.py | 59 +- 17 files changed, 1939 insertions(+), 2 deletions(-) create mode 100644 applications/tools/swinir_denoising.py create mode 100644 configs/swinir_denoising.yaml create mode 100644 docs/en_US/tutorials/swinir.md create mode 100644 docs/zh_CN/tutorials/swinir.md create mode 100644 ppgan/apps/swinir_predictor.py create mode 100644 ppgan/datasets/swinir_dataset.py create mode 100644 ppgan/models/generators/swinir.py create mode 100644 ppgan/models/swinir_model.py create mode 100644 test_tipc/configs/swinir/train_infer_python.txt diff --git a/README.md b/README.md index 51ae60e..770abd4 100644 --- a/README.md +++ b/README.md @@ -119,6 +119,7 @@ GAN-Generative Adversarial Network, was praised by "the Father of Convolutional * [MPR Net](./docs/en_US/tutorials/mpr_net.md) * [FaceEnhancement](./docs/en_US/tutorials/face_enhancement.md) * [PReNet](./docs/en_US/tutorials/prenet.md) +* [SwinIR](./docs/en_US/tutorials/swinir.md) ## Composite Application diff --git a/README_cn.md b/README_cn.md index ffe6433..c4b60b9 100644 --- a/README_cn.md +++ b/README_cn.md @@ -138,7 +138,7 @@ GAN--生成对抗网络,被“卷积网络之父”**Yann LeCun(杨立昆) * 视频超分:[Video Super Resolution(VSR)](./docs/zh_CN/tutorials/video_super_resolution.md) * 包含模型:⭐ PP-MSVSR ⭐、EDVR、BasicVSR、BasicVSR++ * 图像视频修复 - * 图像去模糊去噪去雨:[MPR Net](./docs/zh_CN/tutorials/mpr_net.md) + * 图像去模糊去噪去雨:[MPR Net](./docs/zh_CN/tutorials/mpr_net.md)、[SwinIR](./docs/zh_CN/tutorials/swinir.md) * 视频去模糊:[EDVR](./docs/zh_CN/tutorials/video_super_resolution.md) * 图像去雨:[PReNet](./docs/zh_CN/tutorials/prenet.md) diff --git a/applications/tools/swinir_denoising.py b/applications/tools/swinir_denoising.py new file mode 100644 index 0000000..de2f886 --- /dev/null +++ b/applications/tools/swinir_denoising.py @@ -0,0 +1,60 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+#See the License for the specific language governing permissions and +#limitations under the License. + +import os +import sys +import argparse + +import paddle +from ppgan.apps import SwinIRPredictor + +sys.path.insert(0, os.getcwd()) + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--output_path", + type=str, + default='output_dir', + help="path to output image dir") + + parser.add_argument("--weight_path", + type=str, + default=None, + help="path to model checkpoint path") + + parser.add_argument("--seed", + type=int, + default=None, + help="sample random seed for model's image generation") + + parser.add_argument('--images_path', + default=None, + required=True, + type=str, + help='Single image or images directory.') + + parser.add_argument("--cpu", + dest="cpu", + action="store_true", + help="cpu mode.") + + args = parser.parse_args() + + if args.cpu: + paddle.set_device('cpu') + + predictor = SwinIRPredictor(output_path=args.output_path, + weight_path=args.weight_path, + seed=args.seed) + predictor.run(images_path=args.images_path) diff --git a/configs/swinir_denoising.yaml b/configs/swinir_denoising.yaml new file mode 100644 index 0000000..2ecb5e9 --- /dev/null +++ b/configs/swinir_denoising.yaml @@ -0,0 +1,82 @@ +total_iters: 6400000 +output_dir: output_dir + +model: + name: SwinIRModel + generator: + name: SwinIR + upscale: 1 + img_size: 128 + window_size: 8 + depths: [6, 6, 6, 6, 6, 6] + embed_dim: 180 + num_heads: [6, 6, 6, 6, 6, 6] + mlp_ratio: 2 + char_criterion: + name: CharbonnierLoss + eps: 0.000000001 + reduction: mean + +dataset: + train: + name: SwinIRDataset + num_workers: 8 + batch_size: 2 # 1GPU + opt: + phase: train + n_channels: 3 + H_size: 128 + sigma: 15 + sigma_test: 15 + dataroot_H: data/trainsets/trainH + test: + name: SwinIRDataset + num_workers: 1 + batch_size: 1 + opt: + phase: test + n_channels: 3 + H_size: 128 + sigma: 15 + sigma_test: 15 + dataroot_H: data/trainsets/CBSD68 + +export_model: + - {name: 'generator', inputs_num: 1} + +lr_scheduler: + name: MultiStepDecay + learning_rate: 5e-5 # num_gpu * 5e-5 + milestones: [3200000, 4800000, 5600000, 6000000, 6400000] + gamma: 0.5 + +validate: + interval: 200 + save_img: True + + metrics: + psnr: # metric name, can be arbitrary + name: PSNR + crop_border: 4 + test_y_channel: True + ssim: + name: SSIM + crop_border: 4 + test_y_channel: True + +optimizer: + name: Adam + # add parameters of net_name to optim + # name should in self.nets + net_names: + - generator + beta1: 0.9 + beta2: 0.999 + epsilon: 1e-8 + +log_config: + interval: 10 + visiual_interval: 5000 + +snapshot_config: + interval: 500 diff --git a/docs/en_US/tutorials/swinir.md b/docs/en_US/tutorials/swinir.md new file mode 100644 index 0000000..e7752db --- /dev/null +++ b/docs/en_US/tutorials/swinir.md @@ -0,0 +1,81 @@ +English | [Chinese](../../zh_CN/tutorials/swinir.md) + +## SwinIR Strong Baseline Model for Image Restoration Based on Swin Transformer + +## 1、Introduction + +The structure of SwinIR is relatively simple. If you have seen Swin-Transformer, there is no difficulty. The authors introduce the Swin-T structure for low-level vision tasks, including image super-resolution reconstruction, image denoising, and image compression artifact removal. The SwinIR network consists of a shallow feature extraction module, a deep feature extraction module and a reconstruction module. The reconstruction module uses different structures for different tasks. Shallow feature extraction is a 3×3 convolutional layer. 
Deep feature extraction is composed of K RSTB blocks and a convolutional layer, wrapped in a residual connection. Each RSTB (Residual Swin Transformer Block) consists of L STLs (Swin Transformer Layers) and a convolutional layer, again with a residual connection. The structure of the model is shown in the following figure:
+
+![](https://ai-studio-static-online.cdn.bcebos.com/b550e84915634951af756a545c643c815001be73372248b0b5179fd1652ae003)
+
+For a more detailed introduction to the model, please refer to the original paper [SwinIR: Image Restoration Using Swin Transformer](https://arxiv.org/pdf/2108.10257.pdf). PaddleGAN currently provides weights for the denoising task.
+
+## 2 How to use
+
+### 2.1 Quick start
+
+After installing PaddleGAN, run the following command to generate the restored image.
+
+```sh
+python applications/tools/swinir_denoising.py --images_path ${PATH_OF_IMAGE}
+```
+Where `PATH_OF_IMAGE` is the path of the image you want to denoise, or the path of the folder containing the images.
+
+### 2.2 Prepare dataset
+
+#### Train Dataset
+
+[DIV2K](https://cv.snu.ac.kr/research/EDSR/DIV2K.tar) (800 training images) + [Flickr2K](https://cv.snu.ac.kr/research/EDSR/Flickr2K.tar) (2650 images) + [BSD500](http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/BSR/BSR_bsds500.tgz) (400 training & testing images) + [WED](http://ivc.uwaterloo.ca/database/WaterlooExploration/exploration_database_and_code.rar) (4744 images)
+
+A prepared copy of the training data is available on [Ai Studio](https://aistudio.baidu.com/aistudio/datasetdetail/149405).
+
+Place the training data under: `data/trainsets/trainH`
+
+#### Test Dataset
+
+The test data is CBSD68, available on [Ai Studio](https://aistudio.baidu.com/aistudio/datasetdetail/147756).
+
+Extract it to: `data/trainsets/CBSD68`
+
+### 2.3 Training
+
+The example below trains the denoising model. To train other tasks, change the dataset and modify the config file; the relevant fields are sketched below.
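+
+The fields you would normally edit are all in the `dataset` block of `configs/swinir_denoising.yaml`. The sketch below mirrors the shipped config and only adds notes; `data/trainsets/my_trainH` is a placeholder for wherever your own clean (high-quality) images live, not a directory that PaddleGAN creates.
+
+```yaml
+# Sketch of the dataset block in configs/swinir_denoising.yaml.
+# dataroot_H points at a folder of clean images; SwinIRDataset adds Gaussian
+# noise of level sigma on the fly, so no pre-generated noisy images are needed.
+dataset:
+  train:
+    name: SwinIRDataset
+    num_workers: 8
+    batch_size: 2                           # for 1 GPU
+    opt:
+      phase: train
+      n_channels: 3
+      H_size: 128                           # training patch size
+      sigma: 15                             # noise level added during training
+      sigma_test: 15
+      dataroot_H: data/trainsets/my_trainH  # placeholder: your clean training images
+  test:
+    name: SwinIRDataset
+    num_workers: 1
+    batch_size: 1
+    opt:
+      phase: test
+      n_channels: 3
+      H_size: 128
+      sigma: 15
+      sigma_test: 15                        # noise level used at evaluation time
+      dataroot_H: data/trainsets/CBSD68     # default test set from section 2.2
+```
+
+With the config pointing at your data, launch training with: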
+
+```sh
+python -u tools/main.py --config-file configs/swinir_denoising.yaml
+```
+
+### 2.4 Test
+
+Test the model:
+```sh
+python tools/main.py --config-file configs/swinir_denoising.yaml --evaluate-only --load ${PATH_OF_WEIGHT}
+```
+
+## 3 Results
+
+Denoising
+
+| model | dataset | PSNR/SSIM |
+|---|---|---|
+| SwinIR | CBSD68 | 36.0819 / 0.9464 |
+
+## 4 Download
+
+| model | link |
+|---|---|
+| SwinIR| [SwinIR_Denoising](https://paddlegan.bj.bcebos.com/models/SwinIR_Denoising.pdparams) |
+
+# References
+
+- [SwinIR: Image Restoration Using Swin Transformer](https://arxiv.org/pdf/2108.10257.pdf)
+
+```
+@article{liang2021swinir,
+  title={SwinIR: Image Restoration Using Swin Transformer},
+  author={Liang, Jingyun and Cao, Jiezhang and Sun, Guolei and Zhang, Kai and Van Gool, Luc and Timofte, Radu},
+  journal={arXiv preprint arXiv:2108.10257},
+  year={2021}
+}
+```
+
+
diff --git a/docs/zh_CN/tutorials/swinir.md b/docs/zh_CN/tutorials/swinir.md
new file mode 100644
index 0000000..86f0f8e
--- /dev/null
+++ b/docs/zh_CN/tutorials/swinir.md
@@ -0,0 +1,101 @@
+[English](../../en_US/tutorials/swinir.md) | 中文
+
+## SwinIR 基于Swin Transformer的用于图像恢复的强基线模型
+
+## 1、简介
+
+SwinIR的结构比较简单,如果看过Swin-Transformer的话就没什么难点了。作者引入Swin-T结构应用于低级视觉任务,包括图像超分辨率重建、图像去噪、图像压缩伪影去除。SwinIR网络由一个浅层特征提取模块、深层特征提取模块、重建模块构成。重建模块对不同的任务使用不同的结构。浅层特征提取就是一个3×3的卷积层。深层特征提取是k个RSTB块和一个卷积层加残差连接构成。每个RSTB(Res-Swin-Transformer-Block)由L个STL和一层卷积加残差连接构成。模型的结构如下图所示:
+
+![](https://ai-studio-static-online.cdn.bcebos.com/b550e84915634951af756a545c643c815001be73372248b0b5179fd1652ae003)
+
+对模型更详细的介绍,可参考论文原文[SwinIR: Image Restoration Using Swin Transformer](https://arxiv.org/pdf/2108.10257.pdf),PaddleGAN中目前提供去噪任务的权重。
+
+## 2 如何使用
+
+### 2.1 快速体验
+
+安装`PaddleGAN`之后进入`PaddleGAN`文件夹下,运行如下命令即生成修复后的图像`./output_dir/Denoising/image_name.png`
+
+```sh
+python applications/tools/swinir_denoising.py --images_path ${PATH_OF_IMAGE}
+```
+其中`PATH_OF_IMAGE`为你需要去噪的图像路径,或图像所在文件夹的路径
+
+### 2.2 数据准备
+
+#### 训练数据
+
+[DIV2K](https://cv.snu.ac.kr/research/EDSR/DIV2K.tar) (800 training images) + [Flickr2K](https://cv.snu.ac.kr/research/EDSR/Flickr2K.tar) (2650 images) + [BSD500](http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/BSR/BSR_bsds500.tgz) (400 training&testing images) + [WED](http://ivc.uwaterloo.ca/database/WaterlooExploration/exploration_database_and_code.rar)(4744 images)
+
+已经整理好的数据:放在了 [Ai Studio](https://aistudio.baidu.com/aistudio/datasetdetail/149405) 里。
+
+训练数据放在:`data/trainsets/trainH` 下
+
+#### 测试数据
+
+测试数据为 CBSD68:放在了 [Ai Studio](https://aistudio.baidu.com/aistudio/datasetdetail/147756) 里。
+
+解压到:`data/trainsets/CBSD68`
+
+- 经过处理之后,`PaddleGAN/data`文件夹下的目录结构如下:
+```sh
+trainsets
+├── trainH
+|   |-- 101085.png
+|   |-- 101086.png
+|   |-- ......
+│   └── 201085.png
+└── CBSD68
+    ├── 271035.png
+    |-- ......
+ └── 351093.png +``` + + + +### 2.3 训练 +示例以训练Denoising的数据为例。如果想训练其他任务可以更换数据集并修改配置文件 + +```sh +python -u tools/main.py --config-file configs/swinir_denoising.yaml +``` + +### 2.4 测试 + +测试模型: +```sh +python tools/main.py --config-file configs/swinir_denoising.yaml --evaluate-only --load ${PATH_OF_WEIGHT} +``` + +## 3 结果展示 + +去噪 +| 模型 | 数据集 | PSNR/SSIM | +|---|---|---| +| SwinIR | CBSD68 | 36.0819 / 0.9464 | + + +## 4 模型下载 + +| 模型 | 下载地址 | +|---|---| +| SwinIR| [SwinIR_Denoising](https://paddlegan.bj.bcebos.com/models/SwinIR_Denoising.pdparams) | + + + +# 参考文献 + +- [SwinIR: Image Restoration Using Swin Transformer](https://arxiv.org/pdf/2108.10257.pdf) + +``` +@article{liang2021swinir, + title={SwinIR: Image Restoration Using Swin Transformer}, + author={Liang, Jingyun and Cao, Jiezhang and Sun, Guolei and Zhang, Kai and Van Gool, Luc and Timofte, Radu}, + journal={arXiv preprint arXiv:2108.10257}, + year={2021} +} +``` diff --git a/ppgan/apps/__init__.py b/ppgan/apps/__init__.py index 0f04747..61916cf 100644 --- a/ppgan/apps/__init__.py +++ b/ppgan/apps/__init__.py @@ -37,3 +37,4 @@ from .recurrent_vsr_predictor import (PPMSVSRPredictor, BasicVSRPredictor, \ PPMSVSRLargePredictor) from .singan_predictor import SinGANPredictor from .gpen_predictor import GPENPredictor +from .swinir_predictor import SwinIRPredictor diff --git a/ppgan/apps/swinir_predictor.py b/ppgan/apps/swinir_predictor.py new file mode 100644 index 0000000..fc69a98 --- /dev/null +++ b/ppgan/apps/swinir_predictor.py @@ -0,0 +1,169 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import cv2 +from glob import glob +from natsort import natsorted +import numpy as np +import os +import random +from tqdm import tqdm + +import paddle + +from ppgan.models.generators import SwinIR +from ppgan.utils.download import get_path_from_url +from .base_predictor import BasePredictor + +model_cfgs = { + 'Denoising': { + 'model_urls': + 'https://paddlegan.bj.bcebos.com/models/SwinIR_Denoising.pdparams', + 'upscale': 1, + 'img_size': 128, + 'window_size': 8, + 'depths': [6, 6, 6, 6, 6, 6], + 'embed_dim': 180, + 'num_heads': [6, 6, 6, 6, 6, 6], + 'mlp_ratio': 2 + } +} + + +class SwinIRPredictor(BasePredictor): + + def __init__(self, + output_path='output_dir', + weight_path=None, + seed=None, + window_size=8): + self.output_path = output_path + task = 'Denoising' + self.task = task + self.window_size = window_size + + if weight_path is None: + if task in model_cfgs.keys(): + weight_path = get_path_from_url(model_cfgs[task]['model_urls']) + checkpoint = paddle.load(weight_path) + else: + raise ValueError('Predictor need a task to define!') + else: + if weight_path.startswith("http"): # os.path.islink dosen't work! 
+ weight_path = get_path_from_url(weight_path) + checkpoint = paddle.load(weight_path) + else: + checkpoint = paddle.load(weight_path) + + self.generator = SwinIR(upscale=model_cfgs[task]['upscale'], + img_size=model_cfgs[task]['img_size'], + window_size=model_cfgs[task]['window_size'], + depths=model_cfgs[task]['depths'], + embed_dim=model_cfgs[task]['embed_dim'], + num_heads=model_cfgs[task]['num_heads'], + mlp_ratio=model_cfgs[task]['mlp_ratio']) + + checkpoint = checkpoint['generator'] + self.generator.set_state_dict(checkpoint) + self.generator.eval() + + if seed is not None: + paddle.seed(seed) + random.seed(seed) + np.random.seed(seed) + + def get_images(self, images_path): + if os.path.isdir(images_path): + return natsorted( + glob(os.path.join(images_path, '*.jpeg')) + + glob(os.path.join(images_path, '*.jpg')) + + glob(os.path.join(images_path, '*.JPG')) + + glob(os.path.join(images_path, '*.png')) + + glob(os.path.join(images_path, '*.PNG'))) + else: + return [images_path] + + def imread_uint(self, path, n_channels=3): + # input: path + # output: HxWx3(RGB or GGG), or HxWx1 (G) + if n_channels == 1: + img = cv2.imread(path, 0) # cv2.IMREAD_GRAYSCALE + img = np.expand_dims(img, axis=2) # HxWx1 + elif n_channels == 3: + img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # BGR or G + if img.ndim == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) # GGG + else: + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # RGB + + return img + + def uint2single(self, img): + + return np.float32(img / 255.) + + # convert single (HxWxC) to 3-dimensional paddle tensor + def single2tensor3(self, img): + return paddle.Tensor(np.ascontiguousarray( + img, dtype=np.float32)).transpose([2, 0, 1]) + + def run(self, images_path=None): + os.makedirs(self.output_path, exist_ok=True) + task_path = os.path.join(self.output_path, self.task) + os.makedirs(task_path, exist_ok=True) + image_files = self.get_images(images_path) + for image_file in tqdm(image_files): + img_L = self.imread_uint(image_file, 3) + + image_name = os.path.basename(image_file) + img = cv2.cvtColor(img_L, cv2.COLOR_RGB2BGR) + cv2.imwrite(os.path.join(task_path, image_name), img) + + tmps = image_name.split('.') + assert len( + tmps) == 2, f'Invalid image name: {image_name}, too much "."' + restoration_save_path = os.path.join( + task_path, f'{tmps[0]}_restoration.{tmps[1]}') + + img_L = self.uint2single(img_L) + + # HWC to CHW, numpy to tensor + img_L = self.single2tensor3(img_L) + img_L = img_L.unsqueeze(0) + with paddle.no_grad(): + # pad input image to be a multiple of window_size + _, _, h_old, w_old = img_L.shape + h_pad = (h_old // self.window_size + + 1) * self.window_size - h_old + w_pad = (w_old // self.window_size + + 1) * self.window_size - w_old + img_L = paddle.concat([img_L, paddle.flip(img_L, [2])], + 2)[:, :, :h_old + h_pad, :] + img_L = paddle.concat([img_L, paddle.flip(img_L, [3])], + 3)[:, :, :, :w_old + w_pad] + output = self.generator(img_L) + output = output[..., :h_old, :w_old] + + restored = paddle.clip(output, 0, 1) + + restored = restored.numpy() + restored = restored.transpose(0, 2, 3, 1) + restored = restored[0] + restored = restored * 255 + restored = restored.astype(np.uint8) + + cv2.imwrite(restoration_save_path, + cv2.cvtColor(restored, cv2.COLOR_RGB2BGR)) + + print('Done, output path is:', task_path) diff --git a/ppgan/datasets/__init__.py b/ppgan/datasets/__init__.py index b0c2014..39c3237 100755 --- a/ppgan/datasets/__init__.py +++ b/ppgan/datasets/__init__.py @@ -31,3 +31,4 @@ from .vsr_folder_dataset import 
VSRFolderDataset from .photopen_dataset import PhotoPenDataset from .empty_dataset import EmptyDataset from .gpen_dataset import GPENDataset +from .swinir_dataset import SwinIRDataset diff --git a/ppgan/datasets/swinir_dataset.py b/ppgan/datasets/swinir_dataset.py new file mode 100644 index 0000000..720ae6a --- /dev/null +++ b/ppgan/datasets/swinir_dataset.py @@ -0,0 +1,165 @@ +# code was heavily based on https://github.com/cszn/KAIR +# MIT License +# Copyright (c) 2019 Kai Zhang + +import os +import random +import numpy as np +import cv2 + +import paddle +from paddle.io import Dataset + +from .builder import DATASETS + + +def is_image_file(filename): + return any( + filename.endswith(extension) + for extension in ['jpeg', 'JPEG', 'jpg', 'png', 'JPG', 'PNG', 'gif']) + + +def get_image_paths(dataroot): + paths = None # return None if dataroot is None + if isinstance(dataroot, str): + paths = sorted(_get_paths_from_images(dataroot)) + elif isinstance(dataroot, list): + paths = [] + for i in dataroot: + paths += sorted(_get_paths_from_images(i)) + return paths + + +def _get_paths_from_images(path): + assert os.path.isdir(path), '{:s} is not a valid directory'.format(path) + images = [] + for dirpath, _, fnames in sorted(os.walk(path)): + for fname in sorted(fnames): + if is_image_file(fname): + img_path = os.path.join(dirpath, fname) + images.append(img_path) + assert images, '{:s} has no valid image file'.format(path) + return images + + +def imread_uint(path, n_channels=3): + # input: path + # output: HxWx3(RGB or GGG), or HxWx1 (G) + if n_channels == 1: + img = cv2.imread(path, 0) # cv2.IMREAD_GRAYSCALE + img = np.expand_dims(img, axis=2) # HxWx1 + elif n_channels == 3: + img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # BGR or G + if img.ndim == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) # GGG + else: + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # RGB + return img + + +def augment_img(img, mode=0): + if mode == 0: + return img + elif mode == 1: + return np.flipud(np.rot90(img)) + elif mode == 2: + return np.flipud(img) + elif mode == 3: + return np.rot90(img, k=3) + elif mode == 4: + return np.flipud(np.rot90(img, k=2)) + elif mode == 5: + return np.rot90(img) + elif mode == 6: + return np.rot90(img, k=2) + elif mode == 7: + return np.flipud(np.rot90(img, k=3)) + + +def uint2tensor3(img): + if img.ndim == 2: + img = np.expand_dims(img, axis=2) + return paddle.Tensor(np.ascontiguousarray(img, dtype=np.float32)).transpose( + [2, 0, 1]) / 255. + + +def uint2single(img): + + return np.float32(img / 255.) + + +# convert single (HxWxC) to 3-dimensional paddle tensor +def single2tensor3(img): + return paddle.Tensor(np.ascontiguousarray(img, dtype=np.float32)).transpose( + [2, 0, 1]) + + +@DATASETS.register() +class SwinIRDataset(Dataset): + """ Get L/H for denosing on AWGN with fixed sigma. + Ref: + DnCNN: Beyond a Gaussian Denoiser: Residual Learning of Deep CNN for Image Denoising + Args: + opt (dict): A dictionary defining dataset-related parameters. + """ + + def __init__(self, opt=None): + super(SwinIRDataset, self).__init__() + + print( + 'Dataset: Denosing on AWGN with fixed sigma. Only dataroot_H is needed.' 
+ ) + self.opt = opt + self.n_channels = opt['n_channels'] if opt['n_channels'] else 3 + self.patch_size = opt['H_size'] if opt['H_size'] else 64 + self.sigma = opt['sigma'] if opt['sigma'] else 25 + self.sigma_test = opt['sigma_test'] if opt['sigma_test'] else self.sigma + self.paths_H = get_image_paths(opt['dataroot_H']) + + def __len__(self): + return len(self.paths_H) + + def __getitem__(self, index): + # get H image + H_path = self.paths_H[index] + + img_H = imread_uint(H_path, self.n_channels) + + L_path = H_path + + if self.opt['phase'] == 'train': + # get L/H patch pairs + H, W, _ = img_H.shape + + # randomly crop the patch + rnd_h = random.randint(0, max(0, H - self.patch_size)) + rnd_w = random.randint(0, max(0, W - self.patch_size)) + patch_H = img_H[rnd_h:rnd_h + self.patch_size, + rnd_w:rnd_w + self.patch_size, :] + + # augmentation - flip, rotate + mode = random.randint(0, 7) + patch_H = augment_img(patch_H, mode=mode) + img_H = uint2tensor3(patch_H) + img_L = img_H.clone() + + # add noise + noise = paddle.randn(img_L.shape) * self.sigma / 255.0 + img_L = img_L + noise + + else: + # get L/H image pairs + img_H = uint2single(img_H) + img_L = np.copy(img_H) + + # add noise + np.random.seed(seed=0) + img_L += np.random.normal(0, self.sigma_test / 255.0, img_L.shape) + + # HWC to CHW, numpy to tensor + img_L = single2tensor3(img_L) + img_H = single2tensor3(img_H) + + filename = os.path.splitext(os.path.split(H_path)[-1])[0] + + return img_H, img_L, filename diff --git a/ppgan/models/__init__.py b/ppgan/models/__init__.py index 65331da..e1e2aa7 100644 --- a/ppgan/models/__init__.py +++ b/ppgan/models/__init__.py @@ -38,3 +38,4 @@ from .singan_model import SinGANModel from .rcan_model import RCANModel from .prenet_model import PReNetModel from .gpen_model import GPENModel +from .swinir_model import SwinIRModel diff --git a/ppgan/models/generators/__init__.py b/ppgan/models/generators/__init__.py index 630e105..afb89e9 100755 --- a/ppgan/models/generators/__init__.py +++ b/ppgan/models/generators/__init__.py @@ -42,3 +42,4 @@ from .generator_singan import SinGANGenerator from .rcan import RCAN from .prenet import PReNet from .gpen import GPEN +from .swinir import SwinIR diff --git a/ppgan/models/generators/swinir.py b/ppgan/models/generators/swinir.py new file mode 100644 index 0000000..e6c553a --- /dev/null +++ b/ppgan/models/generators/swinir.py @@ -0,0 +1,1060 @@ +# code was heavily based on https://github.com/cszn/KAIR +# MIT License +# Copyright (c) 2019 Kai Zhang +""" +Droppath, reimplement from https://github.com/yueatsprograms/Stochastic_Depth +""" +from itertools import repeat +import collections.abc +import math +import numpy as np + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + +from .builder import GENERATORS + + +def _ntuple(n): + + def parse(x): + if isinstance(x, collections.abc.Iterable): + return x + return tuple(repeat(x, n)) + + return parse + + +class DropPath(nn.Layer): + """DropPath class""" + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def drop_path(self, inputs): + """drop path op + Args: + input: tensor with arbitrary shape + drop_prob: float number of drop path probability, default: 0.0 + training: bool, if current mode is training, default: False + Returns: + output: output tensor after drop path + """ + # if prob is 0 or eval mode, return original input + if self.drop_prob == 0. 
or not self.training: + return inputs + keep_prob = 1 - self.drop_prob + keep_prob = paddle.to_tensor(keep_prob, dtype='float32') + shape = ( + inputs.shape[0], ) + (1, ) * (inputs.ndim - 1) # shape=(N, 1, 1, 1) + random_tensor = keep_prob + paddle.rand(shape, dtype=inputs.dtype) + random_tensor = random_tensor.floor() # mask + output = inputs.divide( + keep_prob + ) * random_tensor # divide is to keep same output expectation + return output + + def forward(self, inputs): + return self.drop_path(inputs) + + +to_2tuple = _ntuple(2) + + +@paddle.jit.not_to_static +def swapdim(x, dim1, dim2): + a = list(range(len(x.shape))) + a[dim1], a[dim2] = a[dim2], a[dim1] + return x.transpose(a) + + +class Identity(nn.Layer): + """ Identity layer + The output of this layer is the input without any change. + Use this layer to avoid if condition in some forward methods + """ + + def __init__(self): + super(Identity, self).__init__() + + def forward(self, x): + return x + + +class Mlp(nn.Layer): + + def __init__(self, in_features, hidden_features, dropout): + super(Mlp, self).__init__() + w_attr_1, b_attr_1 = self._init_weights() + self.fc1 = nn.Linear(in_features, + hidden_features, + weight_attr=w_attr_1, + bias_attr=b_attr_1) + + w_attr_2, b_attr_2 = self._init_weights() + self.fc2 = nn.Linear(hidden_features, + in_features, + weight_attr=w_attr_2, + bias_attr=b_attr_2) + self.act = nn.GELU() + self.dropout = nn.Dropout(dropout) + + def _init_weights(self): + weight_attr = paddle.ParamAttr( + initializer=paddle.nn.initializer.XavierUniform()) + bias_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.Normal( + std=1e-6)) + return weight_attr, bias_attr + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.dropout(x) + x = self.fc2(x) + x = self.dropout(x) + return x + + +class WindowAttention(nn.Layer): + """Window based multihead attention, with relative position bias. + Both shifted window and non-shifted window are supported. 
+ Args: + dim (int): input dimension (channels) + window_size (int): height and width of the window + num_heads (int): number of attention heads + qkv_bias (bool): if True, enable learnable bias to q,k,v, default: True + qk_scale (float): override default qk scale head_dim**-0.5 if set, default: None + attention_dropout (float): dropout of attention + dropout (float): dropout for output + """ + + def __init__(self, + dim, + window_size, + num_heads, + qkv_bias=True, + qk_scale=None, + attention_dropout=0., + dropout=0.): + super(WindowAttention, self).__init__() + self.window_size = window_size + self.num_heads = num_heads + self.dim = dim + self.dim_head = dim // num_heads + self.scale = qk_scale or self.dim_head**-0.5 + + self.relative_position_bias_table = paddle.create_parameter( + shape=[(2 * window_size[0] - 1) * (2 * window_size[1] - 1), + num_heads], + dtype='float32', + default_initializer=paddle.nn.initializer.TruncatedNormal(std=.02)) + + weight_attr, bias_attr = self._init_weights() + + # relative position index for each token inside window + coords_h = paddle.arange(self.window_size[0]) + coords_w = paddle.arange(self.window_size[1]) + coords = paddle.stack(paddle.meshgrid([coords_h, coords_w + ])) # [2, window_h, window_w] + coords_flatten = paddle.flatten(coords, 1) # [2, window_h * window_w] + # 2, window_h * window_w, window_h * window_h + relative_coords = coords_flatten.unsqueeze( + 2) - coords_flatten.unsqueeze(1) + # winwod_h*window_w, window_h*window_w, 2 + relative_coords = relative_coords.transpose([1, 2, 0]) + relative_coords[:, :, 0] += self.window_size[0] - 1 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + # [window_size * window_size, window_size*window_size] + relative_position_index = relative_coords.sum(-1) + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, + dim * 3, + weight_attr=weight_attr, + bias_attr=bias_attr if qkv_bias else False) + self.attn_dropout = nn.Dropout(attention_dropout) + self.proj = nn.Linear(dim, + dim, + weight_attr=weight_attr, + bias_attr=bias_attr) + self.proj_dropout = nn.Dropout(dropout) + self.softmax = nn.Softmax(axis=-1) + + def transpose_multihead(self, x): + tensor_shape = list(x.shape[:-1]) + new_shape = tensor_shape + [self.num_heads, self.dim_head] + x = x.reshape(new_shape) + x = x.transpose([0, 2, 1, 3]) + return x + + def get_relative_pos_bias_from_pos_index(self): + # relative_position_bias_table is a ParamBase object + # https://github.com/PaddlePaddle/Paddle/blob/067f558c59b34dd6d8626aad73e9943cf7f5960f/python/paddle/fluid/framework.py#L5727 + table = self.relative_position_bias_table # N x num_heads + # index is a tensor + index = self.relative_position_index.reshape( + [-1]) # window_h*window_w * window_h*window_w + # NOTE: paddle does NOT support indexing Tensor by a Tensor + relative_position_bias = paddle.index_select(x=table, index=index) + return relative_position_bias + + def forward(self, x, mask=None): + qkv = self.qkv(x).chunk(3, axis=-1) + q, k, v = map(self.transpose_multihead, qkv) + q = q * self.scale + attn = paddle.matmul(q, k, transpose_y=True) + relative_position_bias = self.get_relative_pos_bias_from_pos_index() + relative_position_bias = relative_position_bias.reshape([ + self.window_size[0] * self.window_size[1], + self.window_size[0] * self.window_size[1], -1 + ]) + # nH, window_h*window_w, window_h*window_w + relative_position_bias = relative_position_bias.transpose([2, 0, 
1]) + attn = attn + relative_position_bias.unsqueeze(0) + if mask is not None: + nW = mask.shape[0] + attn = attn.reshape( + [x.shape[0] // nW, nW, self.num_heads, x.shape[1], x.shape[1]]) + attn += mask.unsqueeze(1).unsqueeze(0) + attn = attn.reshape([-1, self.num_heads, x.shape[1], x.shape[1]]) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_dropout(attn) + + z = paddle.matmul(attn, v) + z = z.transpose([0, 2, 1, 3]) + tensor_shape = list(z.shape[:-2]) + new_shape = tensor_shape + [self.dim] + z = z.reshape(new_shape) + z = self.proj(z) + z = self.proj_dropout(z) + + return z + + def _init_weights(self): + weight_attr = paddle.ParamAttr( + initializer=nn.initializer.TruncatedNormal(std=.02)) + bias_attr = paddle.ParamAttr(initializer=nn.initializer.Constant(0)) + return weight_attr, bias_attr + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + flops += N * self.dim * 3 * self.dim + flops += self.num_heads * N * (self.dim // self.num_heads) * N + flops += self.num_heads * N * N * (self.dim // self.num_heads) + flops += N * self.dim * self.dim + return flops + + +def windows_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.reshape( + [B, H // window_size, window_size, W // window_size, window_size, C]) + windows = x.transpose([0, 1, 3, 2, 4, + 5]).reshape([-1, window_size, window_size, C]) + return windows + + +def windows_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.reshape( + [B, H // window_size, W // window_size, window_size, window_size, -1]) + x = x.transpose([0, 1, 3, 2, 4, 5]).reshape([B, H, W, -1]) + return x + + +class SwinTransformerBlock(nn.Layer): + """Swin transformer block + Contains window multi head self attention, droppath, mlp, norm and residual. + Attributes: + dim: int, input dimension (channels) + input_resolution: int, input resoultion + num_heads: int, number of attention heads + windos_size: int, window size, default: 7 + shift_size: int, shift size for SW-MSA, default: 0 + mlp_ratio: float, ratio of mlp hidden dim and input embedding dim, default: 4. + qkv_bias: bool, if True, enable learnable bias to q,k,v, default: True + qk_scale: float, override default qk scale head_dim**-0.5 if set, default: None + dropout: float, dropout for output, default: 0. + attention_dropout: float, dropout of attention, default: 0. + droppath: float, drop path rate, default: 0. 
+ """ + + def __init__(self, + dim, + input_resolution, + num_heads, + window_size=7, + shift_size=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + dropout=0., + attention_dropout=0., + droppath=0.): + super(SwinTransformerBlock, self).__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + self.shift_size = 0 + self.window_size = min(self.input_resolution) + + self.norm1 = nn.LayerNorm(dim) + self.attn = WindowAttention(dim, + window_size=to_2tuple(self.window_size), + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attention_dropout=attention_dropout, + dropout=dropout) + self.drop_path = DropPath(droppath) if droppath > 0. else Identity() + self.norm2 = nn.LayerNorm(dim) + self.mlp = Mlp(in_features=dim, + hidden_features=int(dim * mlp_ratio), + dropout=dropout) + + attn_mask = self.calculate_mask(self.input_resolution) + + self.register_buffer("attn_mask", attn_mask) + + def calculate_mask(self, x_size): + if self.shift_size > 0: + # calculate attention mask for SW-MSA + H, W = x_size + img_mask = paddle.zeros((1, H, W, 1)) + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, + -self.shift_size), slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, + -self.shift_size), slice(-self.shift_size, None)) + cnt = 0 + + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = windows_partition( + img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.reshape( + [-1, self.window_size * self.window_size]) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + + huns = -100.0 * paddle.ones_like(attn_mask) + attn_mask = huns * (attn_mask != 0).astype("float32") + + return attn_mask + else: + return None + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + + shortcut = x + x = self.norm1(x) + x = x.reshape([B, H, W, C]) + + # cyclic shift + if self.shift_size > 0: + shifted_x = paddle.roll(x, + shifts=(-self.shift_size, -self.shift_size), + axis=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = windows_partition(shifted_x, self.window_size) + x_windows = x_windows.reshape( + [-1, self.window_size * self.window_size, C]) + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size + + if self.input_resolution == x_size: + attn_windows = self.attn(x_windows, mask=self.attn_mask) + else: + attn_windows = self.attn(x_windows, + mask=self.calculate_mask(x_size)) + + # merge windows + attn_windows = attn_windows.reshape( + [-1, self.window_size, self.window_size, C]) + shifted_x = windows_reverse(attn_windows, self.window_size, H, W) + + # reverse cyclic shift + if self.shift_size > 0: + x = paddle.roll(shifted_x, + shifts=(self.shift_size, self.shift_size), + axis=(1, 2)) + else: + x = shifted_x + + x = x.reshape([B, H * W, C]) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * 
W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + + +class PatchMerging(nn.Layer): + """ Patch Merging class + Merge multiple patch into one path and keep the out dim. + Spefically, merge adjacent 2x2 patches(dim=C) into 1 patch. + The concat dim 4*C is rescaled to 2*C + Args: + input_resolution (tuple | ints): the size of input + dim: dimension of single patch + reduction: nn.Linear which maps 4C to 2C dim + norm: nn.LayerNorm, applied after linear layer. + """ + + def __init__(self, input_resolution, dim): + super(PatchMerging, self).__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias_attr=False) + self.norm = nn.LayerNorm(4 * dim) + + def forward(self, x): + h, w = self.input_resolution + b, _, c = x.shape + x = x.reshape([b, h, w, c]) + + x0 = x[:, 0::2, 0::2, :] # [B, H/2, W/2, C] + x1 = x[:, 1::2, 0::2, :] # [B, H/2, W/2, C] + x2 = x[:, 0::2, 1::2, :] # [B, H/2, W/2, C] + x3 = x[:, 1::2, 1::2, :] # [B, H/2, W/2, C] + x = paddle.concat([x0, x1, x2, x3], -1) #[B, H/2, W/2, 4*C] + x = x.reshape([b, -1, 4 * c]) # [B, H/2*W/2, 4*C] + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + + +class BasicLayer(nn.Layer): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + dropout (float, optional): Dropout rate. Default: 0.0 + attention_dropout (float, optional): Attention dropout rate. Default: 0.0 + droppath (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + downsample (nn.Layer | None, optional): Downsample layer at the end of the layer. 
Default: None + """ + + def __init__(self, + dim, + input_resolution, + depth, + num_heads, + window_size, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + dropout=0., + attention_dropout=0., + droppath=0., + downsample=None): + super(BasicLayer, self).__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + + self.blocks = nn.LayerList() + for i in range(depth): + self.blocks.append( + SwinTransformerBlock(dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + window_size=window_size, + shift_size=0 if + (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + dropout=dropout, + attention_dropout=attention_dropout, + droppath=droppath[i] if isinstance( + droppath, list) else droppath)) + + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim) + else: + self.downsample = None + + def forward(self, x, x_size): + for block in self.blocks: + x = block(x, x_size) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + + +class RSTB(nn.Layer): + """Residual Swin Transformer Block (RSTB). + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + downsample (nn.Layer | None, optional): Downsample layer at the end of the layer. Default: None + img_size: Input image size. + patch_size: Patch size. + resi_connection: The convolutional block before residual connection. 
+ """ + + def __init__(self, + dim, + input_resolution, + depth, + num_heads, + window_size, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + downsample=None, + img_size=224, + patch_size=4, + resi_connection='1conv'): + super(RSTB, self).__init__() + + self.dim = dim + self.input_resolution = input_resolution + + self.residual_group = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + dropout=drop, + attention_dropout=attn_drop, + droppath=drop_path, + downsample=downsample) + + if resi_connection == '1conv': + self.conv = nn.Conv2D(dim, dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + self.conv = nn.Sequential(nn.Conv2D(dim, dim // 4, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.2), + nn.Conv2D(dim // 4, dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2), + nn.Conv2D(dim // 4, dim, 3, 1, 1)) + + self.patch_embed = PatchEmbed(img_size=img_size, + patch_size=patch_size, + in_chans=0, + embed_dim=dim, + norm_layer=None) + + self.patch_unembed = PatchUnEmbed(img_size=img_size, + patch_size=patch_size, + in_chans=0, + embed_dim=dim, + norm_layer=None) + + def forward(self, x, x_size): + return self.patch_embed( + self.conv(self.patch_unembed(self.residual_group(x, x_size), + x_size))) + x + + def flops(self): + flops = 0 + flops += self.residual_group.flops() + H, W = self.input_resolution + flops += H * W * self.dim * self.dim * 9 + flops += self.patch_embed.flops() + flops += self.patch_unembed.flops() + + return flops + + +class PatchEmbed(nn.Layer): + r""" Image to Patch Embedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Layer, optional): Normalization layer. Default: None + """ + + def __init__(self, + img_size=224, + patch_size=4, + in_chans=3, + embed_dim=96, + norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [ + img_size[0] // patch_size[0], img_size[1] // patch_size[1] + ] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + x = x.flatten(2).transpose([0, 2, 1]) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + flops = 0 + H, W = self.img_size + if self.norm is not None: + flops += H * W * self.embed_dim + return flops + + +class PatchUnEmbed(nn.Layer): + r""" Image to Patch Unembedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Layer, optional): Normalization layer. 
Default: None + """ + + def __init__(self, + img_size=224, + patch_size=4, + in_chans=3, + embed_dim=96, + norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [ + img_size[0] // patch_size[0], img_size[1] // patch_size[1] + ] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + def forward(self, x, x_size): + B, HW, C = x.shape + x = x.transpose([0, 2, + 1]).reshape([B, self.embed_dim, x_size[0], + x_size[1]]) # B Ph*Pw C + return x + + def flops(self): + flops = 0 + return flops + + +class Upsample(nn.Sequential): + """Upsample module. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2D(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2D(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f'scale {scale} is not supported. ' + 'Supported scales: 2^n and 3.') + super(Upsample, self).__init__(*m) + + +class UpsampleOneStep(nn.Sequential): + """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle) + Used in lightweight SR to save parameters. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + + """ + + def __init__(self, scale, num_feat, num_out_ch, input_resolution=None): + self.num_feat = num_feat + self.input_resolution = input_resolution + m = [] + m.append(nn.Conv2D(num_feat, (scale**2) * num_out_ch, 3, 1, 1)) + m.append(nn.PixelShuffle(scale)) + super(UpsampleOneStep, self).__init__(*m) + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.num_feat * 3 * 9 + return flops + + +@GENERATORS.register() +class SwinIR(nn.Layer): + r""" SwinIR + A Pypaddle impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer. + + Args: + img_size (int | tuple(int)): Input image size. Default 64 + patch_size (int | tuple(int)): Patch size. Default: 1 + in_chans (int): Number of input image channels. Default: 3 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Layer): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + upscale: Upscale factor. 
2/3/4/8 for image SR, 1 for denoising and compress artifact reduction + img_range: Image range. 1. or 255. + upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None + resi_connection: The convolutional block before residual connection. '1conv'/'3conv' + """ + + def __init__(self, + img_size=64, + patch_size=1, + in_chans=3, + embed_dim=96, + depths=[6, 6, 6, 6], + num_heads=[6, 6, 6, 6], + window_size=7, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.1, + norm_layer=nn.LayerNorm, + ape=False, + patch_norm=True, + upscale=2, + img_range=1., + upsampler='', + resi_connection='1conv'): + super(SwinIR, self).__init__() + num_in_ch = in_chans + num_out_ch = in_chans + num_feat = 64 + self.img_range = img_range + if in_chans == 3: + rgb_mean = np.array([0.4488, 0.4371, 0.4040], dtype=np.float32) + self.mean = paddle.Tensor(rgb_mean).reshape([1, 3, 1, 1]) + else: + self.mean = paddle.zeros([1., 1., 1., 1.], dtype=paddle.float32) + self.upscale = upscale + self.upsampler = upsampler + self.window_size = window_size + + # 1. shallow feature extraction + self.conv_first = nn.Conv2D(num_in_ch, embed_dim, 3, 1, 1) + + # 2. deep feature extraction + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = embed_dim + self.mlp_ratio = mlp_ratio + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, + patch_size=patch_size, + in_chans=embed_dim, + embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # merge non-overlapping patches into image + self.patch_unembed = PatchUnEmbed( + img_size=img_size, + patch_size=patch_size, + in_chans=embed_dim, + embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + # absolute position embedding + if self.ape: + self.absolute_pos_embed = paddle.nn.ParameterList([ + paddle.create_parameter( + shape=[1, num_patches, embed_dim], + dtype='float32', + default_initializer=paddle.nn.initializer.TruncatedNormal( + std=.02)) + ]) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [ + x.item() for x in paddle.linspace(0, drop_path_rate, sum(depths)) + ] # stochastic depth decay rule + + # build Residual Swin Transformer blocks (RSTB) + self.layers = nn.LayerList() + for i_layer in range(self.num_layers): + layer = RSTB( + dim=embed_dim, + input_resolution=(patches_resolution[0], patches_resolution[1]), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer] + ):sum(depths[:i_layer + + 1])], # no impact on SR results + downsample=None, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection) + self.layers.append(layer) + self.norm = norm_layer(self.num_features) + + # build the last conv layer in deep feature extraction + if resi_connection == '1conv': + self.conv_after_body = nn.Conv2D(embed_dim, embed_dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + self.conv_after_body = nn.Sequential( + nn.Conv2D(embed_dim, embed_dim // 4, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.2), + nn.Conv2D(embed_dim // 
4, embed_dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2), + nn.Conv2D(embed_dim // 4, embed_dim, 3, 1, 1)) + + # 3, high quality image reconstruction ################################ + if self.upsampler == 'pixelshuffle': + # for classical SR + self.conv_before_upsample = nn.Sequential( + nn.Conv2D(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU()) + self.upsample = Upsample(upscale, num_feat) + self.conv_last = nn.Conv2D(num_feat, num_out_ch, 3, 1, 1) + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR (to save parameters) + self.upsample = UpsampleOneStep( + upscale, embed_dim, num_out_ch, + (patches_resolution[0], patches_resolution[1])) + elif self.upsampler == 'nearest+conv': + # for real-world SR (less artifacts) + assert self.upscale == 4, 'only support x4 now.' + self.conv_before_upsample = nn.Sequential( + nn.Conv2D(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU()) + self.conv_up1 = nn.Conv2D(num_feat, num_feat, 3, 1, 1) + self.conv_up2 = nn.Conv2D(num_feat, num_feat, 3, 1, 1) + self.conv_hr = nn.Conv2D(num_feat, num_feat, 3, 1, 1) + self.conv_last = nn.Conv2D(num_feat, num_out_ch, 3, 1, 1) + self.lrelu = nn.LeakyReLU(negative_slope=0.2) + else: + # for image denoising and JPEG compression artifact reduction + self.conv_last = nn.Conv2D(embed_dim, num_out_ch, 3, 1, 1) + + def no_weight_decay(self): + return {'absolute_pos_embed'} + + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def check_image_size(self, x): + _, _, h, w = x.shape + mod_pad_h = (self.window_size - h % self.window_size) % self.window_size + mod_pad_w = (self.window_size - w % self.window_size) % self.window_size + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') + return x + + def forward_features(self, x): + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers: + x = layer(x, x_size) + + x = self.norm(x) # B L C + x = self.patch_unembed(x, x_size) + + return x + + def forward(self, x): + H, W = x.shape[2:] + x = self.check_image_size(x) + + x = (x - self.mean) * self.img_range + + if self.upsampler == 'pixelshuffle': + # for classical SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.conv_before_upsample(x) + x = self.conv_last(self.upsample(x)) + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.upsample(x) + elif self.upsampler == 'nearest+conv': + # for real-world SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.conv_before_upsample(x) + x = self.lrelu( + self.conv_up1( + paddle.nn.functional.interpolate(x, + scale_factor=2, + mode='nearest'))) + x = self.lrelu( + self.conv_up2( + paddle.nn.functional.interpolate(x, + scale_factor=2, + mode='nearest'))) + x = self.conv_last(self.lrelu(self.conv_hr(x))) + else: + # for image denoising and JPEG compression artifact reduction + x_first = self.conv_first(x) + res = self.conv_after_body(self.forward_features(x_first)) + x_first + x = x + self.conv_last(res) + + x = x / self.img_range + self.mean + + return x[:, :, :H * self.upscale, :W * self.upscale] + + def flops(self): + flops = 0 + H, W = self.patches_resolution + flops += H * W * 3 * self.embed_dim * 9 + flops += self.patch_embed.flops() + for i, layer in enumerate(self.layers): + flops += layer.flops() + flops += H * W * 3 * self.embed_dim * 
self.embed_dim + flops += self.upsample.flops() + return flops diff --git a/ppgan/models/swinir_model.py b/ppgan/models/swinir_model.py new file mode 100644 index 0000000..9fc8c4e --- /dev/null +++ b/ppgan/models/swinir_model.py @@ -0,0 +1,102 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import paddle +import paddle.nn as nn + +from .builder import MODELS +from .base_model import BaseModel +from .generators.builder import build_generator +from .criterions.builder import build_criterion +from ppgan.utils.visual import tensor2img + + +@MODELS.register() +class SwinIRModel(BaseModel): + """SwinIR Model. + SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257 + Originally Written by Ze Liu, Modified by Jingyun Liang. + """ + + def __init__(self, generator, char_criterion=None): + """Initialize the MPR class. + + Args: + generator (dict): config of generator. + char_criterion (dict): config of char criterion. + """ + super(SwinIRModel, self).__init__(generator) + self.current_iter = 1 + + self.nets['generator'] = build_generator(generator) + + if char_criterion: + self.char_criterion = build_criterion(char_criterion) + + def setup_input(self, input): + self.target = input[0] + self.lq = input[1] + + def train_iter(self, optims=None): + optims['optim'].clear_gradients() + + restored = self.nets['generator'](self.lq) + + loss = self.char_criterion(restored, self.target) + + loss.backward() + optims['optim'].step() + self.losses['loss'] = loss.numpy() + + def forward(self): + pass + + def test_iter(self, metrics=None): + self.nets['generator'].eval() + with paddle.no_grad(): + self.output = self.nets['generator'](self.lq) + self.visual_items['output'] = self.output + self.nets['generator'].train() + + out_img = [] + gt_img = [] + for out_tensor, gt_tensor in zip(self.output, self.target): + out_img.append(tensor2img(out_tensor, (0., 1.))) + gt_img.append(tensor2img(gt_tensor, (0., 1.))) + + if metrics is not None: + for metric in metrics.values(): + metric.update(out_img, gt_img) + + def export_model(self, + export_model=None, + output_dir=None, + inputs_size=None, + export_serving_model=False, + model_name=None): + shape = inputs_size[0] + new_model = self.nets['generator'] + new_model.eval() + input_spec = [paddle.static.InputSpec(shape=shape, dtype="float32")] + + static_model = paddle.jit.to_static(new_model, input_spec=input_spec) + + if output_dir is None: + output_dir = 'inference_model' + if model_name is None: + model_name = '{}_{}'.format(self.__class__.__name__.lower(), + export_model[0]['name']) + + paddle.jit.save(static_model, os.path.join(output_dir, model_name)) diff --git a/test_tipc/configs/swinir/train_infer_python.txt b/test_tipc/configs/swinir/train_infer_python.txt new file mode 100644 index 0000000..f41cec2 --- /dev/null +++ b/test_tipc/configs/swinir/train_infer_python.txt @@ -0,0 +1,51 @@ +===========================train_params=========================== 
+model_name:swinir
+python:python3.7
+gpu_list:0
+##
+auto_cast:null
+total_iters:lite_train_lite_infer=10
+output_dir:./output/
+snapshot_config.interval:lite_train_lite_infer=10
+pretrained_model:null
+train_model_name:swinir*/*checkpoint.pdparams
+train_infer_img_dir:null
+null:null
+##
+trainer:norm_train
+norm_train:tools/main.py -c configs/swinir_denoising.yaml --seed 100 -o log_config.interval=1
+pact_train:null
+fpgm_train:null
+distill_train:null
+null:null
+null:null
+##
+===========================eval_params===========================
+eval:null
+null:null
+##
+===========================infer_params===========================
+--output_dir:./output/
+load:null
+norm_export:tools/export_model.py -c configs/swinir_denoising.yaml --inputs_size=1,3,128,128 --model_name inference --load
+quant_export:null
+fpgm_export:null
+distill_export:null
+export1:null
+export2:null
+inference_dir:inference
+train_model:./inference/swinir/swinirmodel_generator
+infer_export:null
+infer_quant:False
+inference:tools/inference.py --model_type swinir --seed 100 -c configs/swinir_denoising.yaml --output_path test_tipc/output/
+--device:gpu
+null:null
+null:null
+null:null
+null:null
+null:null
+--model_path:
+null:null
+null:null
+--benchmark:True
+null:null
diff --git a/test_tipc/prepare.sh b/test_tipc/prepare.sh
index 0dcfac1..733788d 100644
--- a/test_tipc/prepare.sh
+++ b/test_tipc/prepare.sh
@@ -62,6 +62,10 @@ if [ ${MODE} = "lite_train_lite_infer" ];then
         rm -rf ./data/DIV2K*
         wget -nc -P ./data/ https://paddlegan.bj.bcebos.com/datasets/DIV2KandSet14paddle.tar --no-check-certificate
         cd ./data/ && tar xf DIV2KandSet14paddle.tar && cd ../ ;;
+    swinir)
+        rm -rf ./data/*sets
+        wget -nc -P ./data/ https://paddlegan.bj.bcebos.com/datasets/swinir_data.zip --no-check-certificate
+        cd ./data/ && unzip -q swinir_data.zip && cd ../ ;;
     singan)
         rm -rf ./data/singan*
         wget -nc -P ./data/ https://paddlegan.bj.bcebos.com/datasets/singan-official_images.zip --no-check-certificate
diff --git a/tools/inference.py b/tools/inference.py
index 246bacd..0c7110e 100644
--- a/tools/inference.py
+++ b/tools/inference.py
@@ -19,7 +19,7 @@ from ppgan.metrics import build_metric
 
 
 MODEL_CLASSES = ["pix2pix", "cyclegan", "wav2lip", "esrgan", \
-    "edvr", "fom", "stylegan2", "basicvsr", "msvsr", "singan"]
+    "edvr", "fom", "stylegan2", "basicvsr", "msvsr", "singan", "swinir"]
 
 
 def parse_args():
@@ -334,6 +334,63 @@ def main():
             metric_file = os.path.join(args.output_path, "singan/metric.txt")
             for metric in metrics.values():
                 metric.update(prediction, data['A'])
+        elif model_type == "swinir":
+            lq = data[1].numpy()
+            _, _, h_old, w_old = lq.shape
+            window_size = 8
+            tile = 128
+            tile_overlap = 32
+            # pad the input by reflection so that its height and width are multiples of window_size
+            h_pad = (h_old // window_size + 1) * window_size - h_old
+            w_pad = (w_old // window_size + 1) * window_size - w_old
+            lq = np.concatenate([lq, np.flip(lq, 2)],
+                                axis=2)[:, :, :h_old + h_pad, :]
+            lq = np.concatenate([lq, np.flip(lq, 3)],
+                                axis=3)[:, :, :, :w_old + w_pad]
+            lq = lq.astype("float32")
+
+            b, c, h, w = lq.shape
+            tile = min(tile, h, w)
+            assert tile % window_size == 0, "tile size should be a multiple of window_size"
+            sf = 1 # scale
+            stride = tile - tile_overlap
+            h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
+            w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
+            E = np.zeros([b, c, h * sf, w * sf], dtype=np.float32)
+            W = np.zeros_like(E)
+
+            for h_idx in h_idx_list:
+                for w_idx in w_idx_list:
+                    in_patch = lq[..., h_idx:h_idx + tile, w_idx:w_idx + tile]
+                    input_handles[0].copy_from_cpu(in_patch)
+                    predictor.run()
+                    out_patch = output_handle.copy_to_cpu()
+                    out_patch_mask = np.ones_like(out_patch)
+
+                    E[..., h_idx * sf:(h_idx + tile) * sf,
+                      w_idx * sf:(w_idx + tile) * sf] += out_patch
+                    W[..., h_idx * sf:(h_idx + tile) * sf,
+                      w_idx * sf:(w_idx + tile) * sf] += out_patch_mask
+
+            output = np.true_divide(E, W)
+            prediction = output[..., :h_old * sf, :w_old * sf]
+
+            prediction = paddle.to_tensor(prediction)
+            target = tensor2img(data[0], (0., 1.))
+            prediction = tensor2img(prediction, (0., 1.))
+
+            metric_file = os.path.join(args.output_path, model_type,
+                                       "metric.txt")
+            for metric in metrics.values():
+                metric.update(prediction, target)
+
+            lq = tensor2img(data[1], (0., 1.))
+
+            sample_result = np.concatenate((lq, prediction, target), 1)
+            sample = cv2.cvtColor(sample_result, cv2.COLOR_RGB2BGR)
+            file_name = os.path.join(args.output_path, model_type,
+                                     "{}.png".format(i))
+            cv2.imwrite(file_name, sample)
 
     if metrics:
         log_file = open(metric_file, 'a')
--
GitLab