Unverified · Commit 1c66a2b2 authored by kongdebug, committed by GitHub

【兴智杯 Reproduction Competition】NAFNet (#707)

* 【兴智杯 Reproduction Competition】NAFNet

* [feature] add TIPC and README.md
Parent 9ee66047
......@@ -122,6 +122,7 @@ GAN - Generative Adversarial Network - was praised by "the Father of Convolutional
* [SwinIR](./docs/en_US/tutorials/swinir.md)
* [InvDN](./docs/en_US/tutorials/invdn.md)
* [AOT-GAN](./docs/en_US/tutorials/aotgan.md)
* [NAFNet](./docs/en_US/tutorials/nafnet.md)
## Composite Application
......
......@@ -138,7 +138,7 @@ GAN -- Generative Adversarial Networks, praised by **Yann LeCun**, the "father of convolutional networks"
* Video super-resolution: [Video Super Resolution (VSR)](./docs/zh_CN/tutorials/video_super_resolution.md)
  * Models included: ⭐ PP-MSVSR ⭐, EDVR, BasicVSR, BasicVSR++
* Image/video restoration
  * Image deblurring, denoising and deraining: [MPR Net](./docs/zh_CN/tutorials/mpr_net.md), [SwinIR](./docs/zh_CN/tutorials/swinir.md), [InvDN](./docs/zh_CN/tutorials/invdn.md)
  * Image deblurring, denoising and deraining: [MPR Net](./docs/zh_CN/tutorials/mpr_net.md), [SwinIR](./docs/zh_CN/tutorials/swinir.md), [InvDN](./docs/zh_CN/tutorials/invdn.md), [NAFNet](./docs/zh_CN/tutorials/nafnet.md)
* Video deblurring: [EDVR](./docs/zh_CN/tutorials/video_super_resolution.md)
* Image deraining: [PReNet](./docs/zh_CN/tutorials/prenet.md)
* Image inpainting: [AOT-GAN](./docs/zh_CN/tutorials/aotgan.md)
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import argparse
sys.path.insert(0, os.getcwd())
import paddle
from ppgan.apps import NAFNetPredictor
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--output_path",
type=str,
default='output_dir',
help="path to output image dir")
parser.add_argument("--weight_path",
type=str,
default=None,
help="path to model checkpoint path")
parser.add_argument("--seed",
type=int,
default=None,
help="sample random seed for model's image generation")
parser.add_argument('--images_path',
default=None,
required=True,
type=str,
help='Single image or images directory.')
parser.add_argument("--cpu",
dest="cpu",
action="store_true",
help="cpu mode.")
args = parser.parse_args()
if args.cpu:
paddle.set_device('cpu')
predictor = NAFNetPredictor(output_path=args.output_path,
weight_path=args.weight_path,
seed=args.seed)
predictor.run(images_path=args.images_path)
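# Example invocation (illustrative; paths are placeholders):
#   python applications/tools/nafnet_denoising.py \
#       --images_path data/SIDD/val/input --output_path output_dir --seed 42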
total_iters: 3200000
output_dir: output_dir
model:
name: NAFNetModel
generator:
name: NAFNet
img_channel: 3
width: 64
enc_blk_nums: [2, 2, 4, 8]
middle_blk_num: 12
dec_blk_nums: [2, 2, 2, 2]
psnr_criterion:
name: PSNRLoss
dataset:
train:
name: NAFNetTrain
rgb_dir: data/SIDD/train
num_workers: 16
batch_size: 8 # 1GPU
img_options:
patch_size: 256
test:
name: NAFNetVal
rgb_dir: data/SIDD/val
num_workers: 1
batch_size: 1
img_options:
patch_size: 256
export_model:
- {name: 'generator', inputs_num: 1}
lr_scheduler:
name: CosineAnnealingRestartLR
learning_rate: !!float 125e-6 # num_gpu * 0.000125
periods: [3200000]
restart_weights: [1]
eta_min: !!float 1e-7
validate:
interval: 5000
save_img: false
metrics:
psnr: # metric name, can be arbitrary
name: PSNR
crop_border: 4
test_y_channel: True
ssim:
name: SSIM
crop_border: 4
test_y_channel: True
optimizer:
name: AdamW
# add parameters of net_name to optim
# name should in self.nets
net_names:
- generator
weight_decay: 0.0
beta1: 0.9
beta2: 0.9
epsilon: 1e-8
log_config:
interval: 10
visiual_interval: 5000
snapshot_config:
interval: 5000
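# Note (illustrative addition): any value above can be overridden from the
# command line via `-o`, e.g.
#   python -u tools/main.py --config-file configs/nafnet_denoising.yaml -o log_config.interval=1
# (the TIPC config later in this change uses the same mechanism).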
English | [Chinese](../../zh_CN/tutorials/nafnet.md)
## NAFNet: Simple Baselines for Image Restoration
## 1 Introduction
NAFNet first builds an extremely simple baseline, called Baseline, that is computationally efficient yet outperforms the previous SOTA methods; simplifying it further by removing the nonlinear activation units yields NAFNet, whose performance improves again. The proposed models reach new SOTA performance on both the SIDD denoising and GoPro deblurring tasks while cutting computational cost significantly. The network design is shown in the figure below: a UNet with skip connections serves as the overall architecture, the Transformer module of the Restormer block is modified, the activation functions are removed in favor of the simpler and more effective SimpleGate design, and a simplified channel attention mechanism is applied.
![NAFNet](https://ai-studio-static-online.cdn.bcebos.com/699b87449c7e495f8655ae5ac8bc0eb77bed4d9cd828451e8939ddbc5732a704)
For a more detailed introduction to the model, please refer to the original paper [Simple Baselines for Image Restoration](https://arxiv.org/pdf/2204.04676). PaddleGAN currently provides weights for the denoising task.
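The core simplification is small enough to sketch directly. The snippet below mirrors the `SimpleGate` layer defined in `ppgan/models/generators/nafnet.py` later in this change; the demo lines around it are illustrative only.
```python
import paddle
import paddle.nn as nn

class SimpleGate(nn.Layer):
    """Activation-free gate: split the channels in half and multiply.

    The elementwise product replaces nonlinearities such as ReLU/GELU.
    """
    def forward(self, x):
        x1, x2 = x.chunk(2, axis=1)  # split along the channel axis
        return x1 * x2

# A 1x1 conv usually expands the channels first; the gate halves them again.
gate = SimpleGate()
feat = paddle.rand([1, 64, 32, 32])
print(gate(feat).shape)  # [1, 32, 32, 32]
```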
## 2 How to use
### 2.1 Quick start
After installing PaddleGAN, you can run the following command to generate a restored image.
```sh
python applications/tools/nafnet_denoising.py --images_path ${PATH_OF_IMAGE}
```
Here `PATH_OF_IMAGE` is the path of the image you want to denoise, or of the folder containing the images. To use your own model weights, run the following command, where `PATH_OF_MODEL` is the path to the model weights.
```sh
python applications/tools/nafnet_denoising.py --images_path ${PATH_OF_IMAGE} --weight_path ${PATH_OF_MODEL}
```
### 2.2 Prepare dataset
The denoising training dataset is SIDD, an image denoising dataset containing 30,000 noisy images captured under 10 different lighting conditions; it can be downloaded from the [training dataset](https://www.eecs.yorku.ca/~kamel/sidd/dataset.php) and [test dataset](https://drive.google.com/drive/folders/1S44fHXaVxAYW3KLNxK41NYCnyX9S79su) links.
After downloading, decompress it into the `data` directory. The structure of the `SIDD` dataset is then as follows:
```sh
SIDD
├── train
│ ├── input
│ └── target
└── val
├── input
└── target
```
Users can also use the [SIDD data](https://aistudio.baidu.com/aistudio/datasetdetail/149460) on AI Studio, but the folders `input_crops` and `gt_crops` need to be renamed to `input` and `target`, as in the sketch below.
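A hypothetical rename helper (assuming the `input_crops`/`gt_crops` folders sit directly under each split directory; adjust the paths to your actual download):
```python
import os

# Hypothetical layout fix: adapt the AI Studio SIDD folders to the names
# that NAFNetTrain/NAFNetVal expect (data/SIDD/{train,val}/{input,target}).
for split in ("train", "val"):
    base = os.path.join("data", "SIDD", split)
    os.rename(os.path.join(base, "input_crops"), os.path.join(base, "input"))
    os.rename(os.path.join(base, "gt_crops"), os.path.join(base, "target"))
```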
### 2.3 Training
The example below trains a denoising model. To train for another task, switch the dataset and modify the config file accordingly.
```sh
python -u tools/main.py --config-file configs/nafnet_denoising.yaml
```
### 2.4 Test
To test the model, run:
```sh
python tools/main.py --config-file configs/nafnet_denoising.yaml --evaluate-only --load ${PATH_OF_WEIGHT}
```
## 3 Results
Denoising
| model | dataset | PSNR/SSIM |
|---|---|---|
| NAFNet | SIDD Val | 43.1468 / 0.9563 |
## 4 Download
| model | link |
|---|---|
| NAFNet| [NAFNet_Denoising](https://paddlegan.bj.bcebos.com/models/NAFNet_Denoising.pdparams) |
# References
- [Simple Baselines for Image Restoration](https://arxiv.org/pdf/2204.04676)
```
@article{chen_simple_nodate,
title = {Simple {Baselines} for {Image} {Restoration}},
abstract = {Although there have been significant advances in the field of image restoration recently, the system complexity of the state-of-the-art (SOTA) methods is increasing as well, which may hinder the convenient analysis and comparison of methods. In this paper, we propose a simple baseline that exceeds the SOTA methods and is computationally efficient. To further simplify the baseline, we reveal that the nonlinear activation functions, e.g. Sigmoid, ReLU, GELU, Softmax, etc. are not necessary: they could be replaced by multiplication or removed. Thus, we derive a Nonlinear Activation Free Network, namely NAFNet, from the baseline. SOTA results are achieved on various challenging benchmarks, e.g. 33.69 dB PSNR on GoPro (for image deblurring), exceeding the previous SOTA 0.38 dB with only 8.4\% of its computational costs; 40.30 dB PSNR on SIDD (for image denoising), exceeding the previous SOTA 0.28 dB with less than half of its computational costs. The code and the pretrained models will be released at github.com/megvii-research/NAFNet.},
language = {en},
author = {Chen, Liangyu and Chu, Xiaojie and Zhang, Xiangyu and Sun, Jian},
pages = {17}
}
```
[English](../../en_US/tutorials/nafnet.md) | Chinese
# NAFNet: Simple Baselines for Image Restoration
## 1 Introduction
NAFNet first builds an extremely simple baseline, called Baseline, that is computationally efficient yet outperforms the previous SOTA methods; simplifying it further by removing the nonlinear activation units yields NAFNet, whose performance improves again. The proposed models reach new SOTA performance on both the SIDD denoising and GoPro deblurring tasks while cutting computational cost significantly. The network design is shown in the figure below: a UNet with skip connections serves as the overall architecture, the Transformer module of the Restormer block is modified, the activation functions are removed in favor of the simpler and more effective SimpleGate design, and a simplified channel attention mechanism is applied.
![NAFNet](https://ai-studio-static-online.cdn.bcebos.com/699b87449c7e495f8655ae5ac8bc0eb77bed4d9cd828451e8939ddbc5732a704)
For a more detailed introduction to the model, please refer to the original paper [Simple Baselines for Image Restoration](https://arxiv.org/pdf/2204.04676). PaddleGAN currently provides weights for the denoising task.
## 2 How to use
### 2.1 Quick start
After installing `PaddleGAN`, enter the `PaddleGAN` directory and run the following command to generate the restored image at `./output_dir/Denoising/image_name.png`:
```sh
python applications/tools/nafnet_denoising.py --images_path ${PATH_OF_IMAGE}
```
Here `PATH_OF_IMAGE` is the path of the image you want to denoise, or of the folder containing the images. To use your own model weights, run the following command, where `PATH_OF_MODEL` is the path to the model weights:
```sh
python applications/tools/nafnet_denoising.py --images_path ${PATH_OF_IMAGE} --weight_path ${PATH_OF_MODEL}
```
### 2.2 Prepare dataset
The denoising training dataset is SIDD, an image denoising dataset containing 30,000 noisy images captured under 10 different lighting conditions; it can be downloaded from the [training dataset](https://www.eecs.yorku.ca/~kamel/sidd/dataset.php) and [test dataset](https://drive.google.com/drive/folders/1S44fHXaVxAYW3KLNxK41NYCnyX9S79su) links.
After downloading, decompress it into the `data` directory; the structure is then as follows:
```sh
SIDD
├── train
│ ├── input
│ └── target
└── val
├── input
└── target
```
Users can also use the [SIDD data](https://aistudio.baidu.com/aistudio/datasetdetail/149460) on AI Studio, but the folders `input_crops` and `gt_crops` need to be renamed to `input` and `target`.
### 2.3 Training
The example below trains a denoising model. To train for another task, switch the dataset and modify the config file accordingly.
```sh
python -u tools/main.py --config-file configs/nafnet_denoising.yaml
```
### 2.4 Test
To test the model, run:
```sh
python tools/main.py --config-file configs/nafnet_denoising.yaml --evaluate-only --load ${PATH_OF_WEIGHT}
```
## 3 Results
Denoising
| model | dataset | PSNR/SSIM |
|---|---|---|
| NAFNet | SIDD Val | 43.1468 / 0.9563 |
## 4 Download
| model | link |
|---|---|
| NAFNet| [NAFNet_Denoising](https://paddlegan.bj.bcebos.com/models/NAFNet_Denoising.pdparams) |
# References
- [Simple Baselines for Image Restoration](https://arxiv.org/pdf/2204.04676)
```
@article{chen_simple_nodate,
title = {Simple {Baselines} for {Image} {Restoration}},
abstract = {Although there have been significant advances in the field of image restoration recently, the system complexity of the state-of-the-art (SOTA) methods is increasing as well, which may hinder the convenient analysis and comparison of methods. In this paper, we propose a simple baseline that exceeds the SOTA methods and is computationally efficient. To further simplify the baseline, we reveal that the nonlinear activation functions, e.g. Sigmoid, ReLU, GELU, Softmax, etc. are not necessary: they could be replaced by multiplication or removed. Thus, we derive a Nonlinear Activation Free Network, namely NAFNet, from the baseline. SOTA results are achieved on various challenging benchmarks, e.g. 33.69 dB PSNR on GoPro (for image deblurring), exceeding the previous SOTA 0.38 dB with only 8.4\% of its computational costs; 40.30 dB PSNR on SIDD (for image denoising), exceeding the previous SOTA 0.28 dB with less than half of its computational costs. The code and the pretrained models will be released at github.com/megvii-research/NAFNet.},
language = {en},
author = {Chen, Liangyu and Chu, Xiaojie and Zhang, Xiangyu and Sun, Jian},
pages = {17}
}
```
......@@ -39,4 +39,5 @@ from .singan_predictor import SinGANPredictor
from .gpen_predictor import GPENPredictor
from .swinir_predictor import SwinIRPredictor
from .invdn_predictor import InvDNPredictor
from .nafnet_predictor import NAFNetPredictor
from .aotgan_predictor import AOTGANPredictor
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import cv2
from glob import glob
from natsort import natsorted
import numpy as np
import os
import random
from tqdm import tqdm
import paddle
from ppgan.models.generators import NAFNet
from ppgan.utils.download import get_path_from_url
from .base_predictor import BasePredictor
model_cfgs = {
'Denoising': {
'model_urls':
'https://paddlegan.bj.bcebos.com/models/NAFNet_Denoising.pdparams',
'img_channel': 3,
'width': 64,
'enc_blk_nums': [2, 2, 4, 8],
'middle_blk_num': 12,
'dec_blk_nums': [2, 2, 2, 2]
}
}
class NAFNetPredictor(BasePredictor):
def __init__(self,
output_path='output_dir',
weight_path=None,
seed=None,
window_size=8):
self.output_path = output_path
task = 'Denoising'
self.task = task
self.window_size = window_size
if weight_path is None:
if task in model_cfgs.keys():
weight_path = get_path_from_url(model_cfgs[task]['model_urls'])
checkpoint = paddle.load(weight_path)
else:
raise ValueError('Predictor need a task to define!')
else:
            if weight_path.startswith("http"):  # os.path.islink doesn't work!
weight_path = get_path_from_url(weight_path)
checkpoint = paddle.load(weight_path)
else:
checkpoint = paddle.load(weight_path)
self.generator = NAFNet(
img_channel=model_cfgs[task]['img_channel'],
width=model_cfgs[task]['width'],
enc_blk_nums=model_cfgs[task]['enc_blk_nums'],
middle_blk_num=model_cfgs[task]['middle_blk_num'],
dec_blk_nums=model_cfgs[task]['dec_blk_nums'])
checkpoint = checkpoint['generator']
self.generator.set_state_dict(checkpoint)
self.generator.eval()
if seed is not None:
paddle.seed(seed)
random.seed(seed)
np.random.seed(seed)
def get_images(self, images_path):
if os.path.isdir(images_path):
return natsorted(
glob(os.path.join(images_path, '*.jpeg')) +
glob(os.path.join(images_path, '*.jpg')) +
glob(os.path.join(images_path, '*.JPG')) +
glob(os.path.join(images_path, '*.png')) +
glob(os.path.join(images_path, '*.PNG')))
else:
return [images_path]
def imread_uint(self, path, n_channels=3):
# input: path
# output: HxWx3(RGB or GGG), or HxWx1 (G)
if n_channels == 1:
img = cv2.imread(path, 0) # cv2.IMREAD_GRAYSCALE
img = np.expand_dims(img, axis=2) # HxWx1
elif n_channels == 3:
img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # BGR or G
if img.ndim == 2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) # GGG
else:
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # RGB
return img
def uint2single(self, img):
return np.float32(img / 255.)
# convert single (HxWxC) to 3-dimensional paddle tensor
def single2tensor3(self, img):
return paddle.Tensor(np.ascontiguousarray(
img, dtype=np.float32)).transpose([2, 0, 1])
def run(self, images_path=None):
os.makedirs(self.output_path, exist_ok=True)
task_path = os.path.join(self.output_path, self.task)
os.makedirs(task_path, exist_ok=True)
image_files = self.get_images(images_path)
for image_file in tqdm(image_files):
img_L = self.imread_uint(image_file, 3)
image_name = os.path.basename(image_file)
img = cv2.cvtColor(img_L, cv2.COLOR_RGB2BGR)
cv2.imwrite(os.path.join(task_path, image_name), img)
tmps = image_name.split('.')
assert len(
                tmps) == 2, f'Invalid image name: {image_name}, too many "."'
restoration_save_path = os.path.join(
task_path, f'{tmps[0]}_restoration.{tmps[1]}')
img_L = self.uint2single(img_L)
# HWC to CHW, numpy to tensor
img_L = self.single2tensor3(img_L)
img_L = img_L.unsqueeze(0)
with paddle.no_grad():
output = self.generator(img_L)
restored = paddle.clip(output, 0, 1)
restored = restored.numpy()
restored = restored.transpose(0, 2, 3, 1)
restored = restored[0]
restored = restored * 255
restored = restored.astype(np.uint8)
cv2.imwrite(restoration_save_path,
cv2.cvtColor(restored, cv2.COLOR_RGB2BGR))
print('Done, output path is:', task_path)
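# Usage sketch (illustrative addition, not part of the original file): with
# weight_path=None the released denoising weights are downloaded automatically.
if __name__ == '__main__':
    predictor = NAFNetPredictor(output_path='output_dir', seed=42)
    predictor.run(images_path='data/SIDD/val/input')  # single image or directory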
......@@ -35,4 +35,5 @@ from .swinir_dataset import SwinIRDataset
from .gfpgan_datasets import FFHQDegradationDataset
from .paired_image_datasets import PairedImageDataset
from .invdn_dataset import InvDNDataset
from .nafnet_dataset import NAFNetTrain, NAFNetVal, NAFNetTest
from .aotgan_dataset import AOTGANDataset
# code was heavily based on https://github.com/swz30/MPRNet
# Users should be careful about adopting these functions in any commercial matters.
# https://github.com/swz30/MPRNet/blob/main/LICENSE.md
import os
import random
import numpy as np
from PIL import Image
from paddle.io import Dataset
from .builder import DATASETS
from paddle.vision.transforms.functional import to_tensor, adjust_brightness, adjust_saturation, rotate, hflip, vflip, center_crop
def is_image_file(filename):
return any(
filename.endswith(extension)
for extension in ['jpeg', 'JPEG', 'jpg', 'png', 'JPG', 'PNG', 'gif'])
@DATASETS.register()
class NAFNetTrain(Dataset):
def __init__(self, rgb_dir, img_options=None):
super(NAFNetTrain, self).__init__()
inp_files = sorted(os.listdir(os.path.join(rgb_dir, 'input')))
tar_files = sorted(os.listdir(os.path.join(rgb_dir, 'target')))
self.inp_filenames = [
os.path.join(rgb_dir, 'input', x) for x in inp_files
if is_image_file(x)
]
self.tar_filenames = [
os.path.join(rgb_dir, 'target', x) for x in tar_files
if is_image_file(x)
]
self.img_options = img_options
self.sizex = len(self.tar_filenames) # get the size of target
self.ps = self.img_options['patch_size']
def __len__(self):
return self.sizex
def __getitem__(self, index):
index_ = index % self.sizex
ps = self.ps
inp_path = self.inp_filenames[index_]
tar_path = self.tar_filenames[index_]
inp_img = Image.open(inp_path)
tar_img = Image.open(tar_path)
w, h = tar_img.size
padw = ps - w if w < ps else 0
padh = ps - h if h < ps else 0
        # Reflect-pad in case the image is smaller than patch_size.
        # np.pad expects ((top, bottom), (left, right), ...) plus mode=, and
        # works on arrays, so convert from/to PIL around the call.
        if padw != 0 or padh != 0:
            inp_img = Image.fromarray(
                np.pad(np.asarray(inp_img), ((0, padh), (0, padw), (0, 0)),
                       mode='reflect'))
            tar_img = Image.fromarray(
                np.pad(np.asarray(tar_img), ((0, padh), (0, padw), (0, 0)),
                       mode='reflect'))
aug = random.randint(0, 2)
if aug == 1:
inp_img = adjust_brightness(inp_img, 1)
tar_img = adjust_brightness(tar_img, 1)
aug = random.randint(0, 2)
if aug == 1:
sat_factor = 1 + (0.2 - 0.4 * np.random.rand())
inp_img = adjust_saturation(inp_img, sat_factor)
tar_img = adjust_saturation(tar_img, sat_factor)
# Data Augmentations
aug = random.randint(0, 8)
if aug == 1:
inp_img = vflip(inp_img)
tar_img = vflip(tar_img)
elif aug == 2:
inp_img = hflip(inp_img)
tar_img = hflip(tar_img)
elif aug == 3:
inp_img = rotate(inp_img, 90)
tar_img = rotate(tar_img, 90)
elif aug == 4:
inp_img = rotate(inp_img, 90 * 2)
tar_img = rotate(tar_img, 90 * 2)
elif aug == 5:
inp_img = rotate(inp_img, 90 * 3)
tar_img = rotate(tar_img, 90 * 3)
elif aug == 6:
inp_img = rotate(vflip(inp_img), 90)
tar_img = rotate(vflip(tar_img), 90)
elif aug == 7:
inp_img = rotate(hflip(inp_img), 90)
tar_img = rotate(hflip(tar_img), 90)
inp_img = to_tensor(inp_img)
tar_img = to_tensor(tar_img)
hh, ww = tar_img.shape[1], tar_img.shape[2]
rr = random.randint(0, hh - ps)
cc = random.randint(0, ww - ps)
# Crop patch
inp_img = inp_img[:, rr:rr + ps, cc:cc + ps]
tar_img = tar_img[:, rr:rr + ps, cc:cc + ps]
filename = os.path.splitext(os.path.split(tar_path)[-1])[0]
return tar_img, inp_img, filename
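# Note (illustrative addition): __getitem__ returns (target, input, filename);
# NAFNetModel.setup_input later in this change relies on that order
# (input[0] is the ground truth, input[1] the degraded image).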
@DATASETS.register()
class NAFNetVal(Dataset):
def __init__(self, rgb_dir, img_options=None, rgb_dir2=None):
super(NAFNetVal, self).__init__()
inp_files = sorted(os.listdir(os.path.join(rgb_dir, 'input')))
tar_files = sorted(os.listdir(os.path.join(rgb_dir, 'target')))
self.inp_filenames = [
os.path.join(rgb_dir, 'input', x) for x in inp_files
if is_image_file(x)
]
self.tar_filenames = [
os.path.join(rgb_dir, 'target', x) for x in tar_files
if is_image_file(x)
]
self.img_options = img_options
self.sizex = len(self.tar_filenames) # get the size of target
self.ps = self.img_options['patch_size']
def __len__(self):
return self.sizex
def __getitem__(self, index):
index_ = index % self.sizex
ps = self.ps
inp_path = self.inp_filenames[index_]
tar_path = self.tar_filenames[index_]
inp_img = Image.open(inp_path)
tar_img = Image.open(tar_path)
# Validate on center crop
if self.ps is not None:
inp_img = center_crop(inp_img, (ps, ps))
tar_img = center_crop(tar_img, (ps, ps))
inp_img = to_tensor(inp_img)
tar_img = to_tensor(tar_img)
filename = os.path.splitext(os.path.split(tar_path)[-1])[0]
return tar_img, inp_img, filename
@DATASETS.register()
class NAFNetTest(Dataset):
def __init__(self, inp_dir, img_options):
super(NAFNetTest, self).__init__()
inp_files = sorted(os.listdir(inp_dir))
self.inp_filenames = [
os.path.join(inp_dir, x) for x in inp_files if is_image_file(x)
]
self.inp_size = len(self.inp_filenames)
self.img_options = img_options
def __len__(self):
return self.inp_size
def __getitem__(self, index):
path_inp = self.inp_filenames[index]
filename = os.path.splitext(os.path.split(path_inp)[-1])[0]
inp = Image.open(path_inp)
inp = to_tensor(inp)
return inp, filename
......@@ -41,4 +41,5 @@ from .gpen_model import GPENModel
from .swinir_model import SwinIRModel
from .gfpgan_model import GFPGANModel
from .invdn_model import InvDNModel
from .nafnet_model import NAFNetModel
from .aotgan_model import AOTGANModel
......@@ -2,7 +2,7 @@ from .gan_loss import GANLoss
from .perceptual_loss import PerceptualLoss
from .pixel_loss import L1Loss, MSELoss, CharbonnierLoss, \
CalcStyleEmdLoss, CalcContentReltLoss, \
CalcContentLoss, CalcStyleLoss, EdgeLoss
CalcContentLoss, CalcStyleLoss, EdgeLoss, PSNRLoss
from .photopen_perceptual_loss import PhotoPenPerceptualLoss
from .gradient_penalty import GradientPenalty
......
......@@ -31,6 +31,7 @@ class L1Loss():
loss_weight (float): Loss weight for L1 loss. Default: 1.0.
"""
def __init__(self, reduction='mean', loss_weight=1.0):
# when loss weight less than zero return None
if loss_weight <= 0:
......@@ -59,6 +60,7 @@ class CharbonnierLoss():
eps (float): Default: 1e-12.
"""
def __init__(self, eps=1e-12, reduction='sum'):
self.eps = eps
self.reduction = reduction
......@@ -90,6 +92,7 @@ class MSELoss():
loss_weight (float): Loss weight for MSE loss. Default: 1.0.
"""
def __init__(self, reduction='mean', loss_weight=1.0):
# when loss weight less than zero return None
if loss_weight <= 0:
......@@ -119,6 +122,7 @@ class BCEWithLogitsLoss():
Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
loss_weight (float): Loss weight for MSE loss. Default: 1.0.
"""
def __init__(self, reduction='mean', loss_weight=1.0):
# when loss weight less than zero return None
if loss_weight <= 0:
......@@ -161,6 +165,7 @@ def calc_emd_loss(pred, target):
class CalcStyleEmdLoss():
"""Calc Style Emd Loss.
"""
def __init__(self):
super(CalcStyleEmdLoss, self).__init__()
......@@ -183,6 +188,7 @@ class CalcStyleEmdLoss():
class CalcContentReltLoss():
"""Calc Content Relt Loss.
"""
def __init__(self):
super(CalcContentReltLoss, self).__init__()
......@@ -207,6 +213,7 @@ class CalcContentReltLoss():
class CalcContentLoss():
"""Calc Content Loss.
"""
def __init__(self):
self.mse_loss = nn.MSELoss()
......@@ -229,6 +236,7 @@ class CalcContentLoss():
class CalcStyleLoss():
"""Calc Style Loss.
"""
def __init__(self):
self.mse_loss = nn.MSELoss()
......@@ -247,6 +255,7 @@ class CalcStyleLoss():
@CRITERIONS.register()
class EdgeLoss():
def __init__(self):
k = paddle.to_tensor([[.05, .25, .4, .25, .05]])
self.kernel = paddle.matmul(k.t(), k).unsqueeze(0).tile([3, 1, 1, 1])
......@@ -271,3 +280,27 @@ class EdgeLoss():
y.stop_gradient = True
loss = self.loss(self.laplacian_kernel(x), self.laplacian_kernel(y))
return loss
@CRITERIONS.register()
class PSNRLoss(nn.Layer):
def __init__(self, loss_weight=1.0, reduction='mean', toY=False):
super(PSNRLoss, self).__init__()
assert reduction == 'mean'
self.loss_weight = loss_weight
self.scale = 10 / np.log(10)
self.toY = toY
self.coef = paddle.to_tensor(np.array([65.481, 128.553,
24.966])).reshape([1, 3, 1, 1])
def forward(self, pred, target):
if self.toY:
pred = (pred * self.coef).sum(axis=1).unsqueeze(axis=1) + 16.
target = (target * self.coef).sum(axis=1).unsqueeze(axis=1) + 16.
pred, target = pred / 255., target / 255.
return self.loss_weight * self.scale * paddle.log((
(pred - target)**2).mean(axis=[1, 2, 3]) + 1e-8).mean()
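# Note (illustrative addition): since scale = 10 / ln(10), the returned value
# equals 10 * log10(MSE + 1e-8), i.e. minus the PSNR of [0, 1]-ranged images,
# so minimizing this loss maximizes PSNR directly.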
......@@ -46,4 +46,5 @@ from .swinir import SwinIR
from .gfpganv1_clean_arch import GFPGANv1Clean
from .gfpganv1_arch import GFPGANv1, StyleGAN2DiscriminatorGFPGAN
from .invdn import InvDN
from .nafnet import NAFNet, NAFNetLocal
from .generater_aotgan import InpaintGenerator
# Copyright (c) 2022 megvii-model. All Rights Reserved.
# Modified from BasicSR (https://github.com/xinntao/BasicSR)
# Copyright 2018-2020 BasicSR Authors
import paddle
from paddle import nn as nn
import paddle.nn.functional as F
from paddle.autograd import PyLayer
from .builder import GENERATORS
class LayerNormFunction(PyLayer):
@staticmethod
def forward(ctx, x, weight, bias, eps):
ctx.eps = eps
N, C, H, W = x.shape
mu = x.mean(1, keepdim=True)
var = (x - mu).pow(2).mean(1, keepdim=True)
y = (x - mu) / (var + eps).sqrt()
ctx.save_for_backward(y, var, weight)
y = weight.reshape([1, C, 1, 1]) * y + bias.reshape([1, C, 1, 1])
return y
@staticmethod
def backward(ctx, grad_output):
eps = ctx.eps
N, C, H, W = grad_output.shape
y, var, weight = ctx.saved_tensor()
g = grad_output * weight.reshape([1, C, 1, 1])
mean_g = g.mean(axis=1, keepdim=True)
mean_gy = (g * y).mean(axis=1, keepdim=True)
gx = 1. / paddle.sqrt(var + eps) * (g - y * mean_gy - mean_g)
return gx, (grad_output * y).sum(axis=3).sum(axis=2).sum(
axis=0), grad_output.sum(axis=3).sum(axis=2).sum(axis=0)
class LayerNorm2D(nn.Layer):
def __init__(self, channels, eps=1e-6):
super(LayerNorm2D, self).__init__()
self.add_parameter(
'weight',
self.create_parameter(
[channels],
default_initializer=paddle.nn.initializer.Constant(value=1.0)))
self.add_parameter(
'bias',
self.create_parameter(
[channels],
default_initializer=paddle.nn.initializer.Constant(value=0.0)))
self.eps = eps
def forward(self, x):
if self.training:
y = LayerNormFunction.apply(x, self.weight, self.bias, self.eps)
else:
N, C, H, W = x.shape
mu = x.mean(1, keepdim=True)
var = (x - mu).pow(2).mean(1, keepdim=True)
y = (x - mu) / (var + self.eps).sqrt()
y = self.weight.reshape([1, C, 1, 1]) * y + self.bias.reshape(
[1, C, 1, 1])
return y
class AvgPool2D(nn.Layer):
def __init__(self,
kernel_size=None,
base_size=None,
auto_pad=True,
fast_imp=False,
train_size=None):
super().__init__()
self.kernel_size = kernel_size
self.base_size = base_size
self.auto_pad = auto_pad
# only used for fast implementation
self.fast_imp = fast_imp
self.rs = [5, 4, 3, 2, 1]
self.max_r1 = self.rs[0]
self.max_r2 = self.rs[0]
self.train_size = train_size
def extra_repr(self) -> str:
return 'kernel_size={}, base_size={}, stride={}, fast_imp={}'.format(
self.kernel_size, self.base_size, self.kernel_size, self.fast_imp)
def forward(self, x):
if self.kernel_size is None and self.base_size:
train_size = self.train_size
if isinstance(self.base_size, int):
self.base_size = (self.base_size, self.base_size)
self.kernel_size = list(self.base_size)
self.kernel_size[
0] = x.shape[2] * self.base_size[0] // train_size[-2]
self.kernel_size[
1] = x.shape[3] * self.base_size[1] // train_size[-1]
# only used for fast implementation
self.max_r1 = max(1, self.rs[0] * x.shape[2] // train_size[-2])
self.max_r2 = max(1, self.rs[0] * x.shape[3] // train_size[-1])
if self.kernel_size[0] >= x.shape[-2] and self.kernel_size[
1] >= x.shape[-1]:
return F.adaptive_avg_pool2d(x, 1)
if self.fast_imp: # Non-equivalent implementation but faster
h, w = x.shape[2:]
if self.kernel_size[0] >= h and self.kernel_size[1] >= w:
out = F.adaptive_avg_pool2d(x, 1)
else:
r1 = [r for r in self.rs if h % r == 0][0]
r2 = [r for r in self.rs if w % r == 0][0]
# reduction_constraint
r1 = min(self.max_r1, r1)
r2 = min(self.max_r2, r2)
s = x[:, :, ::r1, ::r2].cumsum(axis=-1).cumsum(axis=-2)
n, c, h, w = s.shape
k1, k2 = min(h - 1, self.kernel_size[0] // r1), min(
w - 1, self.kernel_size[1] // r2)
out = (s[:, :, :-k1, :-k2] - s[:, :, :-k1, k2:] -
s[:, :, k1:, :-k2] + s[:, :, k1:, k2:]) / (k1 * k2)
out = paddle.nn.functional.interpolate(out,
scale_factor=(r1, r2))
else:
n, c, h, w = x.shape
s = x.cumsum(axis=-1).cumsum(axis=-2)
s = paddle.nn.functional.pad(s,
[1, 0, 1, 0]) # pad 0 for convenience
k1, k2 = min(h, self.kernel_size[0]), min(w, self.kernel_size[1])
s1, s2, s3, s4 = s[:, :, :-k1, :-k2], s[:, :, :-k1,
k2:], s[:, :,
k1:, :-k2], s[:, :,
k1:,
k2:]
out = s4 + s1 - s2 - s3
out = out / (k1 * k2)
if self.auto_pad:
n, c, h, w = x.shape
_h, _w = out.shape[2:]
pad2d = [(w - _w) // 2, (w - _w + 1) // 2, (h - _h) // 2,
(h - _h + 1) // 2]
out = paddle.nn.functional.pad(out, pad2d, mode='replicate')
return out
def replace_layers(model, base_size, train_size, fast_imp, **kwargs):
for n, m in model.named_children():
if len(list(m.children())) > 0:
## compound module, go inside it
replace_layers(m, base_size, train_size, fast_imp, **kwargs)
if isinstance(m, nn.AdaptiveAvgPool2D):
pool = AvgPool2D(base_size=base_size,
fast_imp=fast_imp,
train_size=train_size)
assert m._output_size == 1
setattr(model, n, pool)
'''
ref.
@article{chu2021tlsc,
title={Revisiting Global Statistics Aggregation for Improving Image Restoration},
author={Chu, Xiaojie and Chen, Liangyu and and Chen, Chengpeng and Lu, Xin},
journal={arXiv preprint arXiv:2112.04491},
year={2021}
}
'''
class Local_Base():
def convert(self, *args, train_size, **kwargs):
replace_layers(self, *args, train_size=train_size, **kwargs)
imgs = paddle.rand(train_size)
with paddle.no_grad():
self.forward(imgs)
class SimpleGate(nn.Layer):
    # Activation-free gating: split the channels into two halves and multiply.
    def forward(self, x):
        x1, x2 = x.chunk(2, axis=1)
        return x1 * x2
class NAFBlock(nn.Layer):
def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.):
super().__init__()
dw_channel = c * DW_Expand
self.conv1 = nn.Conv2D(in_channels=c,
out_channels=dw_channel,
kernel_size=1,
padding=0,
stride=1,
groups=1,
bias_attr=True)
self.conv2 = nn.Conv2D(in_channels=dw_channel,
out_channels=dw_channel,
kernel_size=3,
padding=1,
stride=1,
groups=dw_channel,
bias_attr=True)
self.conv3 = nn.Conv2D(in_channels=dw_channel // 2,
out_channels=c,
kernel_size=1,
padding=0,
stride=1,
groups=1,
bias_attr=True)
# Simplified Channel Attention
self.sca = nn.Sequential(
nn.AdaptiveAvgPool2D(1),
nn.Conv2D(in_channels=dw_channel // 2,
out_channels=dw_channel // 2,
kernel_size=1,
padding=0,
stride=1,
groups=1,
bias_attr=True),
)
# SimpleGate
self.sg = SimpleGate()
ffn_channel = FFN_Expand * c
self.conv4 = nn.Conv2D(in_channels=c,
out_channels=ffn_channel,
kernel_size=1,
padding=0,
stride=1,
groups=1,
bias_attr=True)
self.conv5 = nn.Conv2D(in_channels=ffn_channel // 2,
out_channels=c,
kernel_size=1,
padding=0,
stride=1,
groups=1,
bias_attr=True)
self.norm1 = LayerNorm2D(c)
self.norm2 = LayerNorm2D(c)
self.drop_out_rate = drop_out_rate
self.dropout1 = nn.Dropout(
drop_out_rate) if drop_out_rate > 0. else None
self.dropout2 = nn.Dropout(
drop_out_rate) if drop_out_rate > 0. else None
self.add_parameter(
"beta",
self.create_parameter(
[1, c, 1, 1],
default_initializer=paddle.nn.initializer.Constant(value=0.0)))
self.add_parameter(
"gamma",
self.create_parameter(
[1, c, 1, 1],
default_initializer=paddle.nn.initializer.Constant(value=0.0)))
    def forward(self, inp):
        x = inp
        # First residual branch: LayerNorm -> 1x1 conv (expand) -> 3x3 depthwise
        # conv -> SimpleGate -> simplified channel attention -> 1x1 conv (project).
        x = self.norm1(x)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.sg(x)
        x = x * self.sca(x)
        x = self.conv3(x)
        if self.drop_out_rate > 0:
            x = self.dropout1(x)
        y = inp + x * self.beta  # residual connection with learnable scale beta
        # Second residual branch (FFN): LayerNorm -> 1x1 conv -> SimpleGate -> 1x1 conv.
        x = self.conv4(self.norm2(y))
        x = self.sg(x)
        x = self.conv5(x)
        if self.drop_out_rate > 0:
            x = self.dropout2(x)
        return y + x * self.gamma  # second residual with learnable scale gamma
@GENERATORS.register()
class NAFNet(nn.Layer):
def __init__(self,
img_channel=3,
width=16,
middle_blk_num=1,
enc_blk_nums=[],
dec_blk_nums=[]):
super().__init__()
self.intro = nn.Conv2D(in_channels=img_channel,
out_channels=width,
kernel_size=3,
padding=1,
stride=1,
groups=1,
bias_attr=True)
self.ending = nn.Conv2D(in_channels=width,
out_channels=img_channel,
kernel_size=3,
padding=1,
stride=1,
groups=1,
bias_attr=True)
self.encoders = nn.LayerList()
self.decoders = nn.LayerList()
self.middle_blks = nn.LayerList()
self.ups = nn.LayerList()
self.downs = nn.LayerList()
chan = width
for num in enc_blk_nums:
self.encoders.append(
nn.Sequential(*[NAFBlock(chan) for _ in range(num)]))
self.downs.append(nn.Conv2D(chan, 2 * chan, 2, 2))
chan = chan * 2
self.middle_blks = \
nn.Sequential(
*[NAFBlock(chan) for _ in range(middle_blk_num)]
)
for num in dec_blk_nums:
self.ups.append(
nn.Sequential(nn.Conv2D(chan, chan * 2, 1, bias_attr=False),
nn.PixelShuffle(2)))
chan = chan // 2
self.decoders.append(
nn.Sequential(*[NAFBlock(chan) for _ in range(num)]))
self.padder_size = 2**len(self.encoders)
def forward(self, inp):
B, C, H, W = inp.shape
inp = self.check_image_size(inp)
x = self.intro(inp)
encs = []
for encoder, down in zip(self.encoders, self.downs):
x = encoder(x)
encs.append(x)
x = down(x)
x = self.middle_blks(x)
for decoder, up, enc_skip in zip(self.decoders, self.ups, encs[::-1]):
x = up(x)
x = x + enc_skip
x = decoder(x)
x = self.ending(x)
x = x + inp
return x[:, :, :H, :W]
    def check_image_size(self, x):
        # Pad H and W up to multiples of 2**len(enc_blk_nums) so that every
        # encoder stage can halve the spatial size exactly; forward() crops
        # the output back to the original H x W.
        _, _, h, w = x.shape
        mod_pad_h = (self.padder_size - h % self.padder_size) % self.padder_size
        mod_pad_w = (self.padder_size - w % self.padder_size) % self.padder_size
        x = F.pad(x, [0, mod_pad_w, 0, mod_pad_h])
        return x
@GENERATORS.register()
class NAFNetLocal(Local_Base, NAFNet):
def __init__(self,
*args,
train_size=(1, 3, 256, 256),
fast_imp=False,
**kwargs):
Local_Base.__init__(self)
NAFNet.__init__(self, *args, **kwargs)
N, C, H, W = train_size
base_size = (int(H * 1.5), int(W * 1.5))
self.eval()
with paddle.no_grad():
self.convert(base_size=base_size,
train_size=train_size,
fast_imp=fast_imp)
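# Note (illustrative addition): NAFNetLocal applies TLSC from the reference
# above: self.convert()/replace_layers swap every nn.AdaptiveAvgPool2D for the
# windowed AvgPool2D, so large test images aggregate statistics over
# train_size-sized local regions instead of the whole feature map.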
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import paddle
import paddle.nn as nn
from .builder import MODELS
from .base_model import BaseModel
from .generators.builder import build_generator
from .criterions.builder import build_criterion
from ..utils.visual import tensor2img
@MODELS.register()
class NAFNetModel(BaseModel):
"""NAFNet Model.
Paper: Simple Baselines for Image Restoration
https://arxiv.org/pdf/2204.04676
"""
def __init__(self, generator, psnr_criterion=None):
"""Initialize the MPR class.
Args:
generator (dict): config of generator.
psnr_criterion (dict): config of psnr criterion.
"""
super(NAFNetModel, self).__init__(generator)
self.current_iter = 1
self.nets['generator'] = build_generator(generator)
if psnr_criterion:
self.psnr_criterion = build_criterion(psnr_criterion)
def setup_input(self, input):
self.target = input[0]
self.lq = input[1]
def train_iter(self, optims=None):
optims['optim'].clear_gradients()
restored = self.nets['generator'](self.lq)
loss = self.psnr_criterion(restored, self.target)
loss.backward()
optims['optim'].step()
self.losses['loss'] = loss.numpy()
def forward(self):
pass
def test_iter(self, metrics=None):
self.nets['generator'].eval()
with paddle.no_grad():
self.output = self.nets['generator'](self.lq)
self.visual_items['output'] = self.output
self.nets['generator'].train()
out_img = []
gt_img = []
for out_tensor, gt_tensor in zip(self.output, self.target):
out_img.append(tensor2img(out_tensor, (0., 1.)))
gt_img.append(tensor2img(gt_tensor, (0., 1.)))
if metrics is not None:
for metric in metrics.values():
metric.update(out_img, gt_img)
def export_model(self,
export_model=None,
output_dir=None,
inputs_size=None,
export_serving_model=False,
model_name=None):
shape = inputs_size[0]
new_model = self.nets['generator']
new_model.eval()
input_spec = [paddle.static.InputSpec(shape=shape, dtype="float32")]
static_model = paddle.jit.to_static(new_model, input_spec=input_spec)
if output_dir is None:
output_dir = 'inference_model'
if model_name is None:
model_name = '{}_{}'.format(self.__class__.__name__.lower(),
export_model[0]['name'])
paddle.jit.save(static_model, os.path.join(output_dir, model_name))
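# Export sketch (illustrative; mirrors the `norm_export` line in the TIPC
# config later in this change):
#   python tools/export_model.py -c configs/nafnet_denoising.yaml \
#       --inputs_size=1,3,256,256 --model_name inference --load ${PATH_OF_WEIGHT}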
......@@ -21,3 +21,4 @@ OPTIMIZERS.register(paddle.optimizer.Adam)
OPTIMIZERS.register(paddle.optimizer.SGD)
OPTIMIZERS.register(paddle.optimizer.Momentum)
OPTIMIZERS.register(paddle.optimizer.RMSProp)
OPTIMIZERS.register(paddle.optimizer.AdamW)
===========================train_params===========================
model_name:nafnet
python:python3.7
gpu_list:0
##
auto_cast:null
total_iters:lite_train_lite_infer=10
output_dir:./output/
snapshot_config.interval:lite_train_lite_infer=10
pretrained_model:null
train_model_name:nafnet*/*checkpoint.pdparams
train_infer_img_dir:null
null:null
##
trainer:norm_train
norm_train:tools/main.py -c configs/nafnet_denoising.yaml --seed 100 -o log_config.interval=1
pact_train:null
fpgm_train:null
distill_train:null
null:null
null:null
##
===========================eval_params===========================
eval:null
null:null
##
===========================infer_params===========================
--output_dir:./output/
load:null
norm_export:tools/export_model.py -c configs/nafnet_denoising.yaml --inputs_size=1,3,256,256 --model_name inference --load
quant_export:null
fpgm_export:null
distill_export:null
export1:null
export2:null
inference_dir:inference
train_model:./inference/nafnet/nafnetmodel_generator
infer_export:null
infer_quant:False
inference:tools/inference.py --model_type nafnet --seed 100 -c configs/nafnet_denoising.yaml --output_path test_tipc/output/
--device:gpu
null:null
null:null
null:null
null:null
null:null
--model_path:
null:null
null:null
--benchmark:True
null:null
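# Illustrative TIPC run (the config path and driver script are assumed from the
# standard TIPC layout and are not shown in this change):
#   bash test_tipc/prepare.sh test_tipc/configs/nafnet/train_infer_python.txt lite_train_lite_infer
#   bash test_tipc/test_train_inference_python.sh test_tipc/configs/nafnet/train_infer_python.txt lite_train_lite_infer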
......@@ -70,6 +70,13 @@ if [ ${MODE} = "lite_train_lite_infer" ];then
rm -rf ./data/SIDD_*
wget -nc -P ./data/ https://paddlegan.bj.bcebos.com/datasets/SIDD_mini.zip --no-check-certificate
cd ./data/ && unzip -q SIDD_mini.zip && cd ../ ;;
nafnet)
rm -rf ./data/SIDD*
wget -nc -P ./data/ https://paddlegan.bj.bcebos.com/datasets/SIDD_mini.zip --no-check-certificate
cd ./data/ && unzip -q SIDD_mini.zip && mkdir -p SIDD && mv ./SIDD_Medium_Srgb_Patches_512/* ./SIDD/ \
&& mv ./SIDD_Valid_Srgb_Patches_256/* ./SIDD/ && mv ./SIDD/valid ./SIDD/val \
&& mv ./SIDD/train/GT ./SIDD/train/target && mv ./SIDD/train/Noisy ./SIDD/train/input \
&& mv ./SIDD/val/Noisy ./SIDD/val/input && mv ./SIDD/val/GT ./SIDD/val/target && cd ../ ;;
singan)
rm -rf ./data/singan*
wget -nc -P ./data/ https://paddlegan.bj.bcebos.com/datasets/singan-official_images.zip --no-check-certificate
......
......@@ -19,7 +19,8 @@ from ppgan.metrics import build_metric
MODEL_CLASSES = ["pix2pix", "cyclegan", "wav2lip", "esrgan", \
"edvr", "fom", "stylegan2", "basicvsr", "msvsr", "singan", "swinir", "invdn", "aotgan"]
"edvr", "fom", "stylegan2", "basicvsr", "msvsr", \
"singan", "swinir", "invdn", "aotgan", "nafnet"]
def parse_args():
......@@ -423,8 +424,31 @@ def main():
for metric in metrics.values():
metric.update(image_numpy, gt_numpy)
break
elif model_type == "nafnet":
lq = data[1].numpy()
input_handles[0].copy_from_cpu(lq)
predictor.run()
prediction = output_handle.copy_to_cpu()
prediction = paddle.to_tensor(prediction)
target = tensor2img(data[0], (0., 1.))
prediction = tensor2img(prediction, (0., 1.))
metric_file = os.path.join(args.output_path, model_type,
"metric.txt")
for metric in metrics.values():
metric.update(prediction, target)
lq = tensor2img(data[1], (0., 1.))
sample_result = np.concatenate((lq, prediction, target), 1)
sample = cv2.cvtColor(sample_result, cv2.COLOR_RGB2BGR)
file_name = os.path.join(args.output_path, model_type,
"{}.png".format(i))
cv2.imwrite(file_name, sample)
elif model_type == 'aotgan':
input_data = paddle.concat((data['img'], data['mask']), axis=1).numpy()
input_data = paddle.concat((data['img'], data['mask']),
axis=1).numpy()
input_handles[0].copy_from_cpu(input_data)
predictor.run()
prediction = output_handle.copy_to_cpu()
......