Unverified commit 349bcb10, authored by kongdebug, committed by GitHub

[Paper Reproduction Competition] SwinIR (#692)

* add SwinIR model

* [fix] add swinir

* [fix] fix url bug

* [fix] fix something for pr

* [feature] add README.md for SwinIR

* [fix] fix the testsets path

* fix testsets path again

* [fix] fix test data path
Parent b2f18921
......@@ -119,6 +119,7 @@ GAN-Generative Adversarial Network, was praised by "the Father of Convolutional
* [MPR Net](./docs/en_US/tutorials/mpr_net.md)
* [FaceEnhancement](./docs/en_US/tutorials/face_enhancement.md)
* [PReNet](./docs/en_US/tutorials/prenet.md)
* [SwinIR](./docs/en_US/tutorials/swinir.md)
## Composite Application
......
......@@ -138,7 +138,7 @@ GAN--生成对抗网络,被“卷积网络之父”**Yann LeCun(杨立昆)
* Video super resolution: [Video Super Resolution(VSR)](./docs/zh_CN/tutorials/video_super_resolution.md)
* Included models: ⭐ PP-MSVSR ⭐, EDVR, BasicVSR, BasicVSR++
* Image and video restoration
-* Image deblurring, denoising, and deraining: [MPR Net](./docs/zh_CN/tutorials/mpr_net.md)
+* Image deblurring, denoising, and deraining: [MPR Net](./docs/zh_CN/tutorials/mpr_net.md), [SwinIR](./docs/zh_CN/tutorials/swinir.md)
* Video deblurring: [EDVR](./docs/zh_CN/tutorials/video_super_resolution.md)
* Image deraining: [PReNet](./docs/zh_CN/tutorials/prenet.md)
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import os
import sys
import argparse

import paddle

# Put the repository root on sys.path before importing ppgan,
# so the local package is found when the script is run from the repo.
sys.path.insert(0, os.getcwd())

from ppgan.apps import SwinIRPredictor

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--output_path",
                        type=str,
                        default='output_dir',
                        help="path to output image dir")

    parser.add_argument("--weight_path",
                        type=str,
                        default=None,
                        help="path to model checkpoint")

    parser.add_argument("--seed",
                        type=int,
                        default=None,
                        help="sample random seed for model's image generation")

    parser.add_argument('--images_path',
                        default=None,
                        required=True,
                        type=str,
                        help='Single image or images directory.')

    parser.add_argument("--cpu",
                        dest="cpu",
                        action="store_true",
                        help="cpu mode.")

    args = parser.parse_args()

    if args.cpu:
        paddle.set_device('cpu')

    predictor = SwinIRPredictor(output_path=args.output_path,
                                weight_path=args.weight_path,
                                seed=args.seed)
    predictor.run(images_path=args.images_path)
total_iters: 6400000
output_dir: output_dir

model:
  name: SwinIRModel
  generator:
    name: SwinIR
    upscale: 1
    img_size: 128
    window_size: 8
    depths: [6, 6, 6, 6, 6, 6]
    embed_dim: 180
    num_heads: [6, 6, 6, 6, 6, 6]
    mlp_ratio: 2
  char_criterion:
    name: CharbonnierLoss
    eps: 0.000000001
    reduction: mean

dataset:
  train:
    name: SwinIRDataset
    num_workers: 8
    batch_size: 2  # 1GPU
    opt:
      phase: train
      n_channels: 3
      H_size: 128
      sigma: 15
      sigma_test: 15
      dataroot_H: data/trainsets/trainH
  test:
    name: SwinIRDataset
    num_workers: 1
    batch_size: 1
    opt:
      phase: test
      n_channels: 3
      H_size: 128
      sigma: 15
      sigma_test: 15
      dataroot_H: data/trainsets/CBSD68

export_model:
  - {name: 'generator', inputs_num: 1}

lr_scheduler:
  name: MultiStepDecay
  learning_rate: 5e-5  # num_gpu * 5e-5
  milestones: [3200000, 4800000, 5600000, 6000000, 6400000]
  gamma: 0.5

validate:
  interval: 200
  save_img: True

  metrics:
    psnr:  # metric name, can be arbitrary
      name: PSNR
      crop_border: 4
      test_y_channel: True
    ssim:
      name: SSIM
      crop_border: 4
      test_y_channel: True

optimizer:
  name: Adam
  # add the parameters of each net in net_names to the optimizer
  # each name must be a key in self.nets
  net_names:
    - generator
  beta1: 0.9
  beta2: 0.999
  epsilon: 1e-8

log_config:
  interval: 10
  visiual_interval: 5000

snapshot_config:
  interval: 500
English | [Chinese](../../zh_CN/tutorials/swinir.md)
## SwinIR: A Strong Baseline Model for Image Restoration Based on Swin Transformer
## 1 Introduction
SwinIR has a relatively simple structure that will look familiar to anyone who knows Swin Transformer. The authors apply the Swin Transformer structure to low-level vision tasks, including image super-resolution reconstruction, image denoising, and image compression artifact removal. The SwinIR network consists of three modules: shallow feature extraction, deep feature extraction, and reconstruction. The reconstruction module uses a different structure for each task. Shallow feature extraction is a single 3×3 convolutional layer. Deep feature extraction consists of k RSTB blocks followed by a convolutional layer, wrapped in a residual connection. Each RSTB (Residual Swin Transformer Block) consists of L STLs (Swin Transformer Layers) followed by a convolutional layer, again with a residual connection. The structure of the model is shown in the following figure:
![](https://ai-studio-static-online.cdn.bcebos.com/b550e84915634951af756a545c643c815001be73372248b0b5179fd1652ae003)
For a more detailed introduction to the model, please refer to the original paper [SwinIR: Image Restoration Using Swin Transformer](https://arxiv.org/pdf/2108.10257.pdf). PaddleGAN currently provides weights for the denoising task.
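The residual wiring described in this section can be sketched in a few lines of plain Python. This is only a toy illustration under stated assumptions: features are single floats, and `stl`, `conv`, `rstb`, and `deep_features` are hypothetical stand-ins, not the PaddleGAN `SwinIR` implementation.

```python
from functools import reduce


def rstb(x, stls, conv):
    """One RSTB: L STLs, then a convolution, plus a residual connection."""
    y = reduce(lambda v, layer: layer(v), stls, x)
    return x + conv(y)


def deep_features(x, blocks, conv):
    """k RSTB blocks, then a convolution, plus a residual connection."""
    y = reduce(lambda v, block: block(v), blocks, x)
    return x + conv(y)


# Toy stand-ins (assumptions): scalar "features" and scalar "layers".
stl = lambda v: 0.5 * v
conv = lambda v: 0.1 * v
blocks = [lambda v: rstb(v, [stl, stl], conv) for _ in range(3)]
out = deep_features(1.0, blocks, conv)
```

Because every stage adds its input back to its output, a block whose convolution outputs zero reduces to the identity, which is the property that makes deep stacks of such blocks easier to train.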
## 2 How to use
### 2.1 Quick start
After installing PaddleGAN, run the following command to generate the restored image.
```sh
python applications/tools/swinir_denoising.py --images_path ${PATH_OF_IMAGE}
```
Here `PATH_OF_IMAGE` is the path of the image you want to denoise, or the path of the folder where the images are located.
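Based on `SwinIRPredictor.run` in this PR, each input is copied to `<output_path>/Denoising/` and the restored image is saved next to it with a `_restoration` suffix. The sketch below only reproduces that naming logic; the helper name `expected_outputs` is hypothetical and is not part of PaddleGAN.

```python
import os


def expected_outputs(image_path, output_path="output_dir", task="Denoising"):
    """Return (input_copy_path, restored_path) for one input image."""
    name = os.path.basename(image_path)
    parts = name.split(".")
    # The predictor asserts that the file name contains exactly one dot.
    if len(parts) != 2:
        raise ValueError(f"Invalid image name: {name}")
    task_dir = os.path.join(output_path, task)
    restored = os.path.join(task_dir, f"{parts[0]}_restoration.{parts[1]}")
    return os.path.join(task_dir, name), restored
```

Note that a file name such as `photo.v2.png` is rejected by the predictor's assertion, so renaming such files beforehand avoids the error.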
### 2.2 Prepare dataset
#### Train Dataset
[DIV2K](https://cv.snu.ac.kr/research/EDSR/DIV2K.tar) (800 training images) + [Flickr2K](https://cv.snu.ac.kr/research/EDSR/Flickr2K.tar) (2650 images) + [BSD500](http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/BSR/BSR_bsds500.tgz) (400 training and testing images) + [WED](http://ivc.uwaterloo.ca/database/WaterlooExploration/exploration_database_and_code.rar) (4744 images)
A pre-organized copy of the training data is available on [AI Studio](https://aistudio.baidu.com/aistudio/datasetdetail/149405).
Place the training data under: `data/trainsets/trainH`
#### Test Dataset
The test data is CBSD68, available on [AI Studio](https://aistudio.baidu.com/aistudio/datasetdetail/147756).
Extract it to: `data/trainsets/CBSD68`
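Before training, it can help to confirm that both folders contain images. The following is a minimal sketch, assuming the `data/trainsets/trainH` and `data/trainsets/CBSD68` locations used by the config in this PR; `count_images` and `check_layout` are hypothetical helpers, and the extension list mirrors `is_image_file` in the dataset code.

```python
import os

IMAGE_EXTENSIONS = ("jpeg", "JPEG", "jpg", "png", "JPG", "PNG", "gif")


def count_images(root):
    """Count image files under root, walking subdirectories."""
    total = 0
    for _, _, names in os.walk(root):
        total += sum(name.endswith(IMAGE_EXTENSIONS) for name in names)
    return total


def check_layout(data_root="data/trainsets"):
    """Return image counts for the train and test folders (0 if missing)."""
    return {sub: count_images(os.path.join(data_root, sub))
            for sub in ("trainH", "CBSD68")}


if __name__ == "__main__":
    print(check_layout())
```

A count of 0 for either folder means the dataset would fail its "has no valid image file" assertion at load time.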
### 2.3 Training
The example below trains the denoising model. To train for other tasks, change the dataset and modify the config file.
```sh
python -u tools/main.py --config-file configs/swinir_denoising.yaml
```
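The config trains for 6,400,000 iterations with a `MultiStepDecay` schedule that halves the learning rate at each milestone. The sketch below evaluates that schedule in plain Python using the values from `configs/swinir_denoising.yaml`; it assumes the usual multi-step rule (the rate drops once the step count reaches a milestone) and does not call Paddle itself. `lr_at` is a hypothetical helper.

```python
from bisect import bisect_right

BASE_LR = 5e-5
MILESTONES = [3200000, 4800000, 5600000, 6000000, 6400000]
GAMMA = 0.5


def lr_at(iteration, base_lr=BASE_LR, milestones=MILESTONES, gamma=GAMMA):
    """Learning rate after `iteration` scheduler steps of a multi-step decay."""
    passed = bisect_right(milestones, iteration)  # milestones already reached
    return base_lr * gamma ** passed


if __name__ == "__main__":
    for step in (0, 3200000, 4800000, 6000000):
        print(step, lr_at(step))
```

The config comment `num_gpu * 5e-5` suggests scaling `BASE_LR` with the number of GPUs when training on more than one device.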
### 2.4 Test
Test the model:
```sh
python tools/main.py --config-file configs/swinir_denoising.yaml --evaluate-only --load ${PATH_OF_WEIGHT}
```
## 3 Results
Denoising
| model | dataset | PSNR/SSIM |
|---|---|---|
| SwinIR | CBSD68 | 36.0819 / 0.9464 |
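
For readers unfamiliar with the metric, PSNR is defined as 10·log10(MAX² / MSE). The sketch below computes it for flat sequences of 8-bit pixel values. It is a simplified illustration only: it skips the `crop_border` and `test_y_channel` options enabled in the config, and it is not PaddleGAN's metric implementation.

```python
import math


def psnr(pred, target, max_value=255.0):
    """PSNR in dB between two equally sized sequences of pixel values."""
    if len(pred) != len(target) or not pred:
        raise ValueError("inputs must be non-empty and the same length")
    mse = sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)
    if mse == 0:
        return float("inf")  # identical inputs
    return 10.0 * math.log10(max_value ** 2 / mse)
```

Higher values mean the restored image is closer to the ground truth; identical images give infinite PSNR.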
## 4 Download
| model | link |
|---|---|
| SwinIR| [SwinIR_Denoising](https://paddlegan.bj.bcebos.com/models/SwinIR_Denoising.pdparams) |
# References
- [SwinIR: Image Restoration Using Swin Transformer](https://arxiv.org/pdf/2108.10257.pdf)
```
@article{liang2021swinir,
title={SwinIR: Image Restoration Using Swin Transformer},
author={Liang, Jingyun and Cao, Jiezhang and Sun, Guolei and Zhang, Kai and Van Gool, Luc and Timofte, Radu},
journal={arXiv preprint arXiv:2108.10257},
year={2021}
}
```
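[English](../../en_US/tutorials/swinir.md) | Chinese
## SwinIR: A Strong Baseline Model for Image Restoration Based on Swin Transformer
## 1 Introduction
SwinIR has a relatively simple structure that will look familiar to anyone who knows Swin Transformer. The authors apply the Swin Transformer structure to low-level vision tasks, including image super-resolution reconstruction, image denoising, and image compression artifact removal. The SwinIR network consists of three modules: shallow feature extraction, deep feature extraction, and reconstruction. The reconstruction module uses a different structure for each task. Shallow feature extraction is a single 3×3 convolutional layer. Deep feature extraction consists of k RSTB blocks followed by a convolutional layer, wrapped in a residual connection. Each RSTB (Residual Swin Transformer Block) consists of L STLs (Swin Transformer Layers) followed by a convolutional layer, again with a residual connection. The structure of the model is shown in the following figure:
![](https://ai-studio-static-online.cdn.bcebos.com/b550e84915634951af756a545c643c815001be73372248b0b5179fd1652ae003)
For a more detailed introduction to the model, please refer to the original paper [SwinIR: Image Restoration Using Swin Transformer](https://arxiv.org/pdf/2108.10257.pdf). PaddleGAN currently provides weights for the denoising task.
## 2 How to use
### 2.1 Quick start
After installing `PaddleGAN`, enter the `PaddleGAN` directory and run the following command to generate the restored image `./output_dir/Denoising/image_name.png`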
```sh
python applications/tools/swinir_denoising.py --images_path ${PATH_OF_IMAGE}
```
Here `PATH_OF_IMAGE` is the path of the image you want to denoise, or the path of the folder where the images are located.
### 2.2 Prepare dataset
#### Train Dataset
[DIV2K](https://cv.snu.ac.kr/research/EDSR/DIV2K.tar) (800 training images) + [Flickr2K](https://cv.snu.ac.kr/research/EDSR/Flickr2K.tar) (2650 images) + [BSD500](http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/BSR/BSR_bsds500.tgz) (400 training and testing images) + [WED](http://ivc.uwaterloo.ca/database/WaterlooExploration/exploration_database_and_code.rar) (4744 images)
A pre-organized copy of the training data is available on [AI Studio](https://aistudio.baidu.com/aistudio/datasetdetail/149405).
Place the training data under: `data/trainsets/trainH`
#### Test Dataset
The test data is CBSD68, available on [AI Studio](https://aistudio.baidu.com/aistudio/datasetdetail/147756).
Extract it to: `data/trainsets/CBSD68`
- After these steps, the `PaddleGAN/data` folder is organized as follows:
```sh
trainsets
├── trainH
│   ├── 101085.png
│   ├── 101086.png
│   ├── ......
│   └── 201085.png
└── CBSD68
    ├── 271035.png
    ├── ......
    └── 351093.png
```
### 2.3 Training
The example below trains the denoising model. To train for other tasks, change the dataset and modify the config file.
```sh
python -u tools/main.py --config-file configs/swinir_denoising.yaml
```
### 2.4 Test
Test the model:
```sh
python tools/main.py --config-file configs/swinir_denoising.yaml --evaluate-only --load ${PATH_OF_WEIGHT}
```
## 3 Results
Denoising
| model | dataset | PSNR/SSIM |
|---|---|---|
| SwinIR | CBSD68 | 36.0819 / 0.9464 |
## 4 Download
| model | link |
|---|---|
| SwinIR| [SwinIR_Denoising](https://paddlegan.bj.bcebos.com/models/SwinIR_Denoising.pdparams) |
# References
- [SwinIR: Image Restoration Using Swin Transformer](https://arxiv.org/pdf/2108.10257.pdf)
```
@article{liang2021swinir,
title={SwinIR: Image Restoration Using Swin Transformer},
author={Liang, Jingyun and Cao, Jiezhang and Sun, Guolei and Zhang, Kai and Van Gool, Luc and Timofte, Radu},
journal={arXiv preprint arXiv:2108.10257},
year={2021}
}
```
......@@ -37,3 +37,4 @@ from .recurrent_vsr_predictor import (PPMSVSRPredictor, BasicVSRPredictor, \
PPMSVSRLargePredictor)
from .singan_predictor import SinGANPredictor
from .gpen_predictor import GPENPredictor
from .swinir_predictor import SwinIRPredictor
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import cv2
from glob import glob
from natsort import natsorted
import numpy as np
import os
import random
from tqdm import tqdm
import paddle
from ppgan.models.generators import SwinIR
from ppgan.utils.download import get_path_from_url
from .base_predictor import BasePredictor
model_cfgs = {
    'Denoising': {
        'model_urls':
        'https://paddlegan.bj.bcebos.com/models/SwinIR_Denoising.pdparams',
        'upscale': 1,
        'img_size': 128,
        'window_size': 8,
        'depths': [6, 6, 6, 6, 6, 6],
        'embed_dim': 180,
        'num_heads': [6, 6, 6, 6, 6, 6],
        'mlp_ratio': 2
    }
}


class SwinIRPredictor(BasePredictor):

    def __init__(self,
                 output_path='output_dir',
                 weight_path=None,
                 seed=None,
                 window_size=8):
        self.output_path = output_path
        task = 'Denoising'
        self.task = task
        self.window_size = window_size

        if weight_path is None:
            if task in model_cfgs.keys():
                weight_path = get_path_from_url(model_cfgs[task]['model_urls'])
                checkpoint = paddle.load(weight_path)
            else:
                raise ValueError('Predictor needs a task to be defined!')
        else:
            if weight_path.startswith("http"):  # os.path.islink doesn't work!
                weight_path = get_path_from_url(weight_path)
                checkpoint = paddle.load(weight_path)
            else:
                checkpoint = paddle.load(weight_path)

        self.generator = SwinIR(upscale=model_cfgs[task]['upscale'],
                                img_size=model_cfgs[task]['img_size'],
                                window_size=model_cfgs[task]['window_size'],
                                depths=model_cfgs[task]['depths'],
                                embed_dim=model_cfgs[task]['embed_dim'],
                                num_heads=model_cfgs[task]['num_heads'],
                                mlp_ratio=model_cfgs[task]['mlp_ratio'])

        checkpoint = checkpoint['generator']
        self.generator.set_state_dict(checkpoint)
        self.generator.eval()

        if seed is not None:
            paddle.seed(seed)
            random.seed(seed)
            np.random.seed(seed)

    def get_images(self, images_path):
        if os.path.isdir(images_path):
            return natsorted(
                glob(os.path.join(images_path, '*.jpeg')) +
                glob(os.path.join(images_path, '*.jpg')) +
                glob(os.path.join(images_path, '*.JPG')) +
                glob(os.path.join(images_path, '*.png')) +
                glob(os.path.join(images_path, '*.PNG')))
        else:
            return [images_path]

    def imread_uint(self, path, n_channels=3):
        #  input: path
        # output: HxWx3(RGB or GGG), or HxWx1 (G)
        if n_channels == 1:
            img = cv2.imread(path, 0)  # cv2.IMREAD_GRAYSCALE
            img = np.expand_dims(img, axis=2)  # HxWx1
        elif n_channels == 3:
            img = cv2.imread(path, cv2.IMREAD_UNCHANGED)  # BGR or G
            if img.ndim == 2:
                img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)  # GGG
            else:
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # RGB

        return img

    def uint2single(self, img):
        return np.float32(img / 255.)

    # convert single (HxWxC) to 3-dimensional paddle tensor
    def single2tensor3(self, img):
        return paddle.Tensor(np.ascontiguousarray(
            img, dtype=np.float32)).transpose([2, 0, 1])

    def run(self, images_path=None):
        os.makedirs(self.output_path, exist_ok=True)
        task_path = os.path.join(self.output_path, self.task)
        os.makedirs(task_path, exist_ok=True)
        image_files = self.get_images(images_path)
        for image_file in tqdm(image_files):
            img_L = self.imread_uint(image_file, 3)

            # save a copy of the input next to the restored result
            image_name = os.path.basename(image_file)
            img = cv2.cvtColor(img_L, cv2.COLOR_RGB2BGR)
            cv2.imwrite(os.path.join(task_path, image_name), img)

            tmps = image_name.split('.')
            assert len(
                tmps) == 2, f'Invalid image name: {image_name}, too many "."'
            restoration_save_path = os.path.join(
                task_path, f'{tmps[0]}_restoration.{tmps[1]}')

            img_L = self.uint2single(img_L)

            # HWC to CHW, numpy to tensor
            img_L = self.single2tensor3(img_L)
            img_L = img_L.unsqueeze(0)
            with paddle.no_grad():
                # pad input image to be a multiple of window_size
                _, _, h_old, w_old = img_L.shape
                h_pad = (h_old // self.window_size +
                         1) * self.window_size - h_old
                w_pad = (w_old // self.window_size +
                         1) * self.window_size - w_old
                img_L = paddle.concat([img_L, paddle.flip(img_L, [2])],
                                      2)[:, :, :h_old + h_pad, :]
                img_L = paddle.concat([img_L, paddle.flip(img_L, [3])],
                                      3)[:, :, :, :w_old + w_pad]
                output = self.generator(img_L)
                # crop the padding away again
                output = output[..., :h_old, :w_old]

            restored = paddle.clip(output, 0, 1)

            restored = restored.numpy()
            restored = restored.transpose(0, 2, 3, 1)
            restored = restored[0]
            restored = restored * 255
            restored = restored.astype(np.uint8)

            cv2.imwrite(restoration_save_path,
                        cv2.cvtColor(restored, cv2.COLOR_RGB2BGR))

        print('Done, output path is:', task_path)
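
The padding step in `run` can be illustrated on a one-dimensional list. The sketch below uses the same arithmetic as the predictor (concatenate the input with its flipped copy, then crop to the padded length); `pad_amount` and `mirror_pad` are hypothetical helper names, and the example assumes the input is at least as long as the padding it needs.

```python
def pad_amount(size, window_size=8):
    """Padding used by the predictor: always rounds up to the NEXT multiple."""
    return (size // window_size + 1) * window_size - size


def mirror_pad(row, window_size=8):
    """Pad a 1-D sequence by appending its mirror image, then cropping."""
    target = len(row) + pad_amount(len(row), window_size)
    # Assumption: len(row) >= pad_amount(...), otherwise the mirrored copy
    # is too short to reach the target length.
    return (row + row[::-1])[:target]


if __name__ == "__main__":
    print(pad_amount(10), pad_amount(16))
    print(mirror_pad(list(range(10))))
```

Note that a size that is already a multiple of `window_size` still receives one full extra window of padding under this formula; the extra rows and columns are cropped away after inference.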
......@@ -31,3 +31,4 @@ from .vsr_folder_dataset import VSRFolderDataset
from .photopen_dataset import PhotoPenDataset
from .empty_dataset import EmptyDataset
from .gpen_dataset import GPENDataset
from .swinir_dataset import SwinIRDataset
# code was heavily based on https://github.com/cszn/KAIR
# MIT License
# Copyright (c) 2019 Kai Zhang
import os
import random
import numpy as np
import cv2
import paddle
from paddle.io import Dataset
from .builder import DATASETS
def is_image_file(filename):
    return any(
        filename.endswith(extension)
        for extension in ['jpeg', 'JPEG', 'jpg', 'png', 'JPG', 'PNG', 'gif'])


def get_image_paths(dataroot):
    paths = None  # return None if dataroot is None
    if isinstance(dataroot, str):
        paths = sorted(_get_paths_from_images(dataroot))
    elif isinstance(dataroot, list):
        paths = []
        for i in dataroot:
            paths += sorted(_get_paths_from_images(i))
    return paths


def _get_paths_from_images(path):
    assert os.path.isdir(path), '{:s} is not a valid directory'.format(path)
    images = []
    for dirpath, _, fnames in sorted(os.walk(path)):
        for fname in sorted(fnames):
            if is_image_file(fname):
                img_path = os.path.join(dirpath, fname)
                images.append(img_path)
    assert images, '{:s} has no valid image file'.format(path)
    return images


def imread_uint(path, n_channels=3):
    #  input: path
    # output: HxWx3(RGB or GGG), or HxWx1 (G)
    if n_channels == 1:
        img = cv2.imread(path, 0)  # cv2.IMREAD_GRAYSCALE
        img = np.expand_dims(img, axis=2)  # HxWx1
    elif n_channels == 3:
        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)  # BGR or G
        if img.ndim == 2:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)  # GGG
        else:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # RGB

    return img


def augment_img(img, mode=0):
    if mode == 0:
        return img
    elif mode == 1:
        return np.flipud(np.rot90(img))
    elif mode == 2:
        return np.flipud(img)
    elif mode == 3:
        return np.rot90(img, k=3)
    elif mode == 4:
        return np.flipud(np.rot90(img, k=2))
    elif mode == 5:
        return np.rot90(img)
    elif mode == 6:
        return np.rot90(img, k=2)
    elif mode == 7:
        return np.flipud(np.rot90(img, k=3))


def uint2tensor3(img):
    if img.ndim == 2:
        img = np.expand_dims(img, axis=2)
    return paddle.Tensor(np.ascontiguousarray(img, dtype=np.float32)).transpose(
        [2, 0, 1]) / 255.


def uint2single(img):
    return np.float32(img / 255.)


# convert single (HxWxC) to 3-dimensional paddle tensor
def single2tensor3(img):
    return paddle.Tensor(np.ascontiguousarray(img, dtype=np.float32)).transpose(
        [2, 0, 1])
@DATASETS.register()
class SwinIRDataset(Dataset):
    """ Get L/H for denoising on AWGN with fixed sigma.
    Ref:
    DnCNN: Beyond a Gaussian Denoiser: Residual Learning of Deep CNN for Image Denoising
    Args:
        opt (dict): A dictionary defining dataset-related parameters.
    """

    def __init__(self, opt=None):
        super(SwinIRDataset, self).__init__()
        print(
            'Dataset: Denoising on AWGN with fixed sigma. Only dataroot_H is needed.'
        )
        self.opt = opt
        self.n_channels = opt['n_channels'] if opt['n_channels'] else 3
        self.patch_size = opt['H_size'] if opt['H_size'] else 64
        self.sigma = opt['sigma'] if opt['sigma'] else 25
        self.sigma_test = opt['sigma_test'] if opt['sigma_test'] else self.sigma
        self.paths_H = get_image_paths(opt['dataroot_H'])

    def __len__(self):
        return len(self.paths_H)

    def __getitem__(self, index):
        # get H image
        H_path = self.paths_H[index]
        img_H = imread_uint(H_path, self.n_channels)

        L_path = H_path

        if self.opt['phase'] == 'train':
            # get L/H patch pairs
            H, W, _ = img_H.shape

            # randomly crop the patch
            rnd_h = random.randint(0, max(0, H - self.patch_size))
            rnd_w = random.randint(0, max(0, W - self.patch_size))
            patch_H = img_H[rnd_h:rnd_h + self.patch_size,
                            rnd_w:rnd_w + self.patch_size, :]

            # augmentation - flip, rotate
            mode = random.randint(0, 7)
            patch_H = augment_img(patch_H, mode=mode)
            img_H = uint2tensor3(patch_H)
            img_L = img_H.clone()

            # add noise
            noise = paddle.randn(img_L.shape) * self.sigma / 255.0
            img_L = img_L + noise

        else:
            # get L/H image pairs
            img_H = uint2single(img_H)
            img_L = np.copy(img_H)

            # add noise (fixed seed so every evaluation sees the same noise)
            np.random.seed(seed=0)
            img_L += np.random.normal(0, self.sigma_test / 255.0, img_L.shape)

            # HWC to CHW, numpy to tensor
            img_L = single2tensor3(img_L)
            img_H = single2tensor3(img_H)

        filename = os.path.splitext(os.path.split(H_path)[-1])[0]

        return img_H, img_L, filename
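
The noise model used by the dataset is additive white Gaussian noise whose standard deviation is `sigma / 255` on pixels scaled to [0, 1]. The following standard-library sketch shows that scaling on a flat list of pixel values; `add_awgn` is a hypothetical helper, and it uses Python's `random` module rather than the Paddle and NumPy generators used in the dataset code.

```python
import random


def add_awgn(pixels, sigma, seed=None):
    """Add Gaussian noise with std sigma/255 to pixels scaled to [0, 1]."""
    rng = random.Random(seed)
    std = sigma / 255.0
    return [p + rng.gauss(0.0, std) for p in pixels]


if __name__ == "__main__":
    clean = [0.2, 0.5, 0.8]
    print(add_awgn(clean, sigma=15, seed=0))
```

Fixing the seed, as the test branch of the dataset does with `np.random.seed(seed=0)`, makes the noisy inputs reproducible across evaluation runs.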
......@@ -38,3 +38,4 @@ from .singan_model import SinGANModel
from .rcan_model import RCANModel
from .prenet_model import PReNetModel
from .gpen_model import GPENModel
from .swinir_model import SwinIRModel
......@@ -42,3 +42,4 @@ from .generator_singan import SinGANGenerator
from .rcan import RCAN
from .prenet import PReNet
from .gpen import GPEN
from .swinir import SwinIR
This diff is collapsed.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import paddle
import paddle.nn as nn
from .builder import MODELS
from .base_model import BaseModel
from .generators.builder import build_generator
from .criterions.builder import build_criterion
from ppgan.utils.visual import tensor2img
@MODELS.register()
class SwinIRModel(BaseModel):
    """SwinIR Model.
    SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257
    Originally Written by Ze Liu, Modified by Jingyun Liang.
    """

    def __init__(self, generator, char_criterion=None):
        """Initialize the SwinIR model.

        Args:
            generator (dict): config of generator.
            char_criterion (dict): config of char criterion.
        """
        super(SwinIRModel, self).__init__(generator)
        self.current_iter = 1

        self.nets['generator'] = build_generator(generator)

        if char_criterion:
            self.char_criterion = build_criterion(char_criterion)

    def setup_input(self, input):
        # the dataset returns (img_H, img_L, filename)
        self.target = input[0]
        self.lq = input[1]

    def train_iter(self, optims=None):
        optims['optim'].clear_gradients()

        restored = self.nets['generator'](self.lq)

        loss = self.char_criterion(restored, self.target)

        loss.backward()
        optims['optim'].step()
        self.losses['loss'] = loss.numpy()

    def forward(self):
        pass

    def test_iter(self, metrics=None):
        self.nets['generator'].eval()
        with paddle.no_grad():
            self.output = self.nets['generator'](self.lq)
            self.visual_items['output'] = self.output
        self.nets['generator'].train()

        out_img = []
        gt_img = []
        for out_tensor, gt_tensor in zip(self.output, self.target):
            out_img.append(tensor2img(out_tensor, (0., 1.)))
            gt_img.append(tensor2img(gt_tensor, (0., 1.)))

        if metrics is not None:
            for metric in metrics.values():
                metric.update(out_img, gt_img)

    def export_model(self,
                     export_model=None,
                     output_dir=None,
                     inputs_size=None,
                     export_serving_model=False,
                     model_name=None):
        shape = inputs_size[0]
        new_model = self.nets['generator']
        new_model.eval()
        input_spec = [paddle.static.InputSpec(shape=shape, dtype="float32")]

        static_model = paddle.jit.to_static(new_model, input_spec=input_spec)

        if output_dir is None:
            output_dir = 'inference_model'
        if model_name is None:
            model_name = '{}_{}'.format(self.__class__.__name__.lower(),
                                        export_model[0]['name'])

        paddle.jit.save(static_model, os.path.join(output_dir, model_name))
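
The training loss is configured as `CharbonnierLoss` with `eps: 0.000000001` and `reduction: mean`. Its implementation is not part of this diff, so the sketch below shows the commonly used form of the Charbonnier loss, sqrt((x − y)² + eps) averaged over all elements, as an assumption about what the criterion computes. `charbonnier_loss` is a hypothetical pure-Python helper.

```python
import math


def charbonnier_loss(pred, target, eps=1e-9):
    """Mean of sqrt((pred - target)^2 + eps) over all elements."""
    terms = [math.sqrt((p - t) ** 2 + eps) for p, t in zip(pred, target)]
    return sum(terms) / len(terms)


if __name__ == "__main__":
    print(charbonnier_loss([0.2, 0.9], [0.25, 0.7]))
```

For large errors this behaves like an L1 loss, while the small `eps` keeps the gradient well defined where the error is zero.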
===========================train_params===========================
model_name:swinir
python:python3.7
gpu_list:0
##
auto_cast:null
total_iters:lite_train_lite_infer=10
output_dir:./output/
snapshot_config.interval:lite_train_lite_infer=10
pretrained_model:null
train_model_name:swinir*/*checkpoint.pdparams
train_infer_img_dir:null
null:null
##
trainer:norm_train
norm_train:tools/main.py -c configs/swinir_denoising.yaml --seed 100 -o log_config.interval=1
pact_train:null
fpgm_train:null
distill_train:null
null:null
null:null
##
===========================eval_params===========================
eval:null
null:null
##
===========================infer_params===========================
--output_dir:./output/
load:null
norm_export:tools/export_model.py -c configs/swinir_denoising.yaml --inputs_size=1,3,128,128 --model_name inference --load
quant_export:null
fpgm_export:null
distill_export:null
export1:null
export2:null
inference_dir:inference
train_model:./inference/swinir/swinirmodel_generator
infer_export:null
infer_quant:False
inference:tools/inference.py --model_type swinir --seed 100 -c configs/swinir_denoising.yaml --output_path test_tipc/output/
--device:gpu
null:null
null:null
null:null
null:null
null:null
--model_path:
null:null
null:null
--benchmark:True
null:null
......@@ -62,6 +62,10 @@ if [ ${MODE} = "lite_train_lite_infer" ];then
rm -rf ./data/DIV2K*
wget -nc -P ./data/ https://paddlegan.bj.bcebos.com/datasets/DIV2KandSet14paddle.tar --no-check-certificate
cd ./data/ && tar xf DIV2KandSet14paddle.tar && cd ../ ;;
swinir)
rm -rf ./data/*sets
wget -nc -P ./data/ https://paddlegan.bj.bcebos.com/datasets/swinir_data.zip --no-check-certificate
cd ./data/ && unzip -q swinir_data.zip && cd ../ ;;
singan)
rm -rf ./data/singan*
wget -nc -P ./data/ https://paddlegan.bj.bcebos.com/datasets/singan-official_images.zip --no-check-certificate
......
......@@ -19,7 +19,7 @@ from ppgan.metrics import build_metric
MODEL_CLASSES = ["pix2pix", "cyclegan", "wav2lip", "esrgan", \
-                 "edvr", "fom", "stylegan2", "basicvsr", "msvsr", "singan"]
+                 "edvr", "fom", "stylegan2", "basicvsr", "msvsr", "singan", "swinir"]
def parse_args():
......@@ -334,6 +334,63 @@ def main():
            metric_file = os.path.join(args.output_path, "singan/metric.txt")
            for metric in metrics.values():
                metric.update(prediction, data['A'])

        elif model_type == "swinir":
            lq = data[1].numpy()
            _, _, h_old, w_old = lq.shape
            window_size = 8
            tile = 128
            tile_overlap = 32
            # pad the input so that height and width are multiples of window_size
            h_pad = (h_old // window_size + 1) * window_size - h_old
            w_pad = (w_old // window_size + 1) * window_size - w_old
            lq = np.concatenate([lq, np.flip(lq, 2)],
                                axis=2)[:, :, :h_old + h_pad, :]
            lq = np.concatenate([lq, np.flip(lq, 3)],
                                axis=3)[:, :, :, :w_old + w_pad]
            lq = lq.astype("float32")

            b, c, h, w = lq.shape
            tile = min(tile, h, w)
            assert tile % window_size == 0, "tile size should be a multiple of window_size"
            sf = 1  # scale
            stride = tile - tile_overlap
            h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
            w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
            # E accumulates tile outputs, W counts how many tiles cover each pixel
            E = np.zeros([b, c, h * sf, w * sf], dtype=np.float32)
            W = np.zeros_like(E)

            for h_idx in h_idx_list:
                for w_idx in w_idx_list:
                    in_patch = lq[..., h_idx:h_idx + tile, w_idx:w_idx + tile]
                    input_handles[0].copy_from_cpu(in_patch)
                    predictor.run()
                    out_patch = output_handle.copy_to_cpu()
                    out_patch_mask = np.ones_like(out_patch)

                    E[..., h_idx * sf:(h_idx + tile) * sf,
                      w_idx * sf:(w_idx + tile) * sf] += out_patch
                    W[..., h_idx * sf:(h_idx + tile) * sf,
                      w_idx * sf:(w_idx + tile) * sf] += out_patch_mask

            # average the overlapping regions, then crop the padding
            output = np.true_divide(E, W)
            prediction = output[..., :h_old * sf, :w_old * sf]

            prediction = paddle.to_tensor(prediction)
            target = tensor2img(data[0], (0., 1.))
            prediction = tensor2img(prediction, (0., 1.))

            metric_file = os.path.join(args.output_path, model_type,
                                       "metric.txt")
            for metric in metrics.values():
                metric.update(prediction, target)

            lq = tensor2img(data[1], (0., 1.))

            sample_result = np.concatenate((lq, prediction, target), 1)
            sample = cv2.cvtColor(sample_result, cv2.COLOR_RGB2BGR)
            file_name = os.path.join(args.output_path, model_type,
                                     "{}.png".format(i))
            cv2.imwrite(file_name, sample)

    if metrics:
        log_file = open(metric_file, 'a')
......
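
The tiled inference added to `tools/inference.py` splits the padded image into overlapping 128×128 tiles with a stride of `tile - tile_overlap` and averages the overlaps. The one-dimensional sketch below reproduces the index computation and the coverage count (the role played by `W`); `tile_starts` and `coverage` are hypothetical helper names.

```python
def tile_starts(length, tile, stride):
    """Start offsets of tiles along one axis; the last tile is flush with the end."""
    return list(range(0, length - tile, stride)) + [length - tile]


def coverage(length, tile, stride):
    """How many tiles cover each position (the 1-D analogue of W)."""
    counts = [0] * length
    for start in tile_starts(length, tile, stride):
        for i in range(start, start + tile):
            counts[i] += 1
    return counts


if __name__ == "__main__":
    # tile=128 and tile_overlap=32 give stride=96, as in the inference code.
    print(tile_starts(328, 128, 96))
    print(min(coverage(328, 128, 96)))
```

Because the final tile is always anchored at `length - tile`, every position is covered at least once, so dividing the accumulated output `E` by the count `W` never divides by zero.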