From 1c66a2b26fe67e44630100436fffd34a35638841 Mon Sep 17 00:00:00 2001
From: kongdebug <52785738+kongdebug@users.noreply.github.com>
Date: Wed, 19 Oct 2022 10:21:03 +0800
Subject: [PATCH] [Xingzhi Cup Reproduction Competition] NAFNet (#707)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* [Xingzhi Cup Reproduction Competition] NAFNet

* [feature] add TIPC and README.md
---
 README.md                                     |   1 +
 README_cn.md                                  |   2 +-
 applications/tools/nafnet_denoising.py        |  59 +++
 configs/nafnet_denoising.yaml                 |  72 ++++
 docs/en_US/tutorials/nafnet.md                |  87 ++++
 docs/zh_CN/tutorials/nafnet.md                |  87 ++++
 ppgan/apps/__init__.py                        |   1 +
 ppgan/apps/nafnet_predictor.py                | 155 +++++++
 ppgan/datasets/__init__.py                    |   1 +
 ppgan/datasets/nafnet_dataset.py              | 193 +++++++++
 ppgan/models/__init__.py                      |   1 +
 ppgan/models/criterions/__init__.py           |   2 +-
 ppgan/models/criterions/pixel_loss.py         |  33 ++
 ppgan/models/generators/__init__.py           |   1 +
 ppgan/models/generators/nafnet.py             | 407 ++++++++++++++++++
 ppgan/models/nafnet_model.py                  | 104 +++++
 ppgan/solver/optimizer.py                     |   1 +
 .../configs/nafnet/train_infer_python.txt     |  51 +++
 test_tipc/prepare.sh                          |   7 +
 tools/inference.py                            |  28 +-
 20 files changed, 1289 insertions(+), 4 deletions(-)
 create mode 100644 applications/tools/nafnet_denoising.py
 create mode 100644 configs/nafnet_denoising.yaml
 create mode 100644 docs/en_US/tutorials/nafnet.md
 create mode 100644 docs/zh_CN/tutorials/nafnet.md
 create mode 100644 ppgan/apps/nafnet_predictor.py
 mode change 100755 => 100644 ppgan/datasets/__init__.py
 create mode 100644 ppgan/datasets/nafnet_dataset.py
 mode change 100755 => 100644 ppgan/models/generators/__init__.py
 create mode 100644 ppgan/models/generators/nafnet.py
 create mode 100644 ppgan/models/nafnet_model.py
 create mode 100644 test_tipc/configs/nafnet/train_infer_python.txt

diff --git a/README.md b/README.md
index 9b2354a..25c066e 100644
--- a/README.md
+++ b/README.md
@@ -122,6 +122,7 @@ GAN-Generative Adversarial Network, was praised by "the Father of Convolutional
 * [SwinIR](./docs/en_US/tutorials/swinir.md)
 * [InvDN](./docs/en_US/tutorials/invdn.md)
 * [AOT-GAN](./docs/en_US/tutorials/aotgan.md)
+* [NAFNet](./docs/en_US/tutorials/nafnet.md)
 
 ## Composite Application
 
diff --git a/README_cn.md b/README_cn.md
index 48f53fa..1a2a5b8 100644
--- a/README_cn.md
+++ b/README_cn.md
@@ -138,7 +138,7 @@ GAN--Generative Adversarial Network, praised by "the father of convolutional networks" **Yann LeCun**
   * Video super-resolution: [Video Super Resolution(VSR)](./docs/zh_CN/tutorials/video_super_resolution.md)
     * Included models: ⭐ PP-MSVSR ⭐, EDVR, BasicVSR, BasicVSR++
   * Image and video restoration
-    * Image deblurring, denoising and deraining: [MPR Net](./docs/zh_CN/tutorials/mpr_net.md), [SwinIR](./docs/zh_CN/tutorials/swinir.md), [InvDN](./docs/zh_CN/tutorials/invdn.md)
+    * Image deblurring, denoising and deraining: [MPR Net](./docs/zh_CN/tutorials/mpr_net.md), [SwinIR](./docs/zh_CN/tutorials/swinir.md), [InvDN](./docs/zh_CN/tutorials/invdn.md), [NAFNet](./docs/zh_CN/tutorials/nafnet.md)
     * Video deblurring: [EDVR](./docs/zh_CN/tutorials/video_super_resolution.md)
     * Image deraining: [PReNet](./docs/zh_CN/tutorials/prenet.md)
     * Image inpainting: [AOT-GAN](./docs/zh_CN/tutorials/aotgan.md)
diff --git a/applications/tools/nafnet_denoising.py b/applications/tools/nafnet_denoising.py
new file mode 100644
index 0000000..7ad6e7a
--- /dev/null
+++ b/applications/tools/nafnet_denoising.py
@@ -0,0 +1,59 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+#Licensed under the Apache License, Version 2.0 (the "License");
+#you may not use this file except in compliance with the License.
+#You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+#Unless required by applicable law or agreed to in writing, software
+#distributed under the License is distributed on an "AS IS" BASIS,
+#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#See the License for the specific language governing permissions and
+#limitations under the License.
+
+import os
+import sys
+import argparse
+
+sys.path.insert(0, os.getcwd())
+import paddle
+from ppgan.apps import NAFNetPredictor
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--output_path",
+                        type=str,
+                        default='output_dir',
+                        help="path to output image dir")
+
+    parser.add_argument("--weight_path",
+                        type=str,
+                        default=None,
+                        help="path to model checkpoint")
+
+    parser.add_argument("--seed",
+                        type=int,
+                        default=None,
+                        help="sample random seed for model's image generation")
+
+    parser.add_argument('--images_path',
+                        default=None,
+                        required=True,
+                        type=str,
+                        help='Single image or images directory.')
+
+    parser.add_argument("--cpu",
+                        dest="cpu",
+                        action="store_true",
+                        help="cpu mode.")
+
+    args = parser.parse_args()
+
+    if args.cpu:
+        paddle.set_device('cpu')
+
+    predictor = NAFNetPredictor(output_path=args.output_path,
+                                weight_path=args.weight_path,
+                                seed=args.seed)
+    predictor.run(images_path=args.images_path)
diff --git a/configs/nafnet_denoising.yaml b/configs/nafnet_denoising.yaml
new file mode 100644
index 0000000..f3a8dda
--- /dev/null
+++ b/configs/nafnet_denoising.yaml
@@ -0,0 +1,72 @@
+total_iters: 3200000
+output_dir: output_dir
+
+model:
+  name: NAFNetModel
+  generator:
+    name: NAFNet
+    img_channel: 3
+    width: 64
+    enc_blk_nums: [2, 2, 4, 8]
+    middle_blk_num: 12
+    dec_blk_nums: [2, 2, 2, 2]
+  psnr_criterion:
+    name: PSNRLoss
+
+dataset:
+  train:
+    name: NAFNetTrain
+    rgb_dir: data/SIDD/train
+    num_workers: 16
+    batch_size: 8  # per GPU (1-GPU setting)
+    img_options:
+      patch_size: 256
+  test:
+    name: NAFNetVal
+    rgb_dir: data/SIDD/val
+    num_workers: 1
+    batch_size: 1
+    img_options:
+      patch_size: 256
+
+export_model:
+  - {name: 'generator', inputs_num: 1}
+
+lr_scheduler:
+  name: CosineAnnealingRestartLR
+  learning_rate: !!float 125e-6  # num_gpu * 0.000125
+  periods: [3200000]
+  restart_weights: [1]
+  eta_min: !!float 1e-7
+
+validate:
+  interval: 5000
+  save_img: false
+
+  metrics:
+    psnr: # metric name, can be arbitrary
+      name: PSNR
+      crop_border: 4
+      test_y_channel: True
+    ssim:
+      name: SSIM
+      crop_border: 4
+      test_y_channel: True
+
+optimizer:
+  name: AdamW
+  # add the parameters of the networks listed in net_names to the optimizer;
+  # the names should exist in self.nets
+  net_names:
+    - generator
+  weight_decay: 0.0
+  beta1: 0.9
+  beta2: 0.9
+  epsilon: 1e-8
+
+log_config:
+  interval: 10
+  visiual_interval: 5000
+
+snapshot_config:
+  interval: 5000
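+
+# Note: any key in this file can be overridden at launch time instead of
+# editing the file, via the -o option of tools/main.py (the TIPC config in
+# this patch drives training through the same mechanism), e.g.:
+#     python -u tools/main.py --config-file configs/nafnet_denoising.yaml \
+#         -o log_config.interval=1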
diff --git a/docs/en_US/tutorials/nafnet.md b/docs/en_US/tutorials/nafnet.md
new file mode 100644
index 0000000..9bb9a2d
--- /dev/null
+++ b/docs/en_US/tutorials/nafnet.md
@@ -0,0 +1,87 @@
+English | [Chinese](../../zh_CN/tutorials/nafnet.md)
+
+# NAFNet: Simple Baselines for Image Restoration
+
+## 1 Introduction
+
+NAFNet proposes an ultra-simple baseline network, dubbed Baseline, that is both computationally efficient and stronger than the previous SOTA methods; simplifying this Baseline further, by removing its non-linear activation units, yields NAFNet and improves performance once more. The resulting model sets new SOTA results on both SIDD denoising and GoPro deblurring while substantially reducing the computational cost. As shown in the figure below, the network uses a UNet with skip connections as the overall architecture, modifies the Transformer module of the Restormer block, removes the activation functions in favor of the simpler and more efficient SimpleGate design, and applies a simplified channel attention mechanism.
+
+![NAFNet](https://ai-studio-static-online.cdn.bcebos.com/699b87449c7e495f8655ae5ac8bc0eb77bed4d9cd828451e8939ddbc5732a704)
+
+For a more detailed introduction to the model, please refer to the original paper [Simple Baselines for Image Restoration](https://arxiv.org/pdf/2204.04676). PaddleGAN currently provides the weights for the denoising task.
+
+## 2 How to use
+
+### 2.1 Quick start
+
+After installing PaddleGAN, you can run the following command to generate restored images:
+
+```sh
+python applications/tools/nafnet_denoising.py --images_path ${PATH_OF_IMAGE}
+```
+Here `PATH_OF_IMAGE` is the path of the image you need to denoise, or of a folder containing the images. To use your own model weights, run the following command, where `PATH_OF_MODEL` is the path of the weights:
+
+```sh
+python applications/tools/nafnet_denoising.py --images_path ${PATH_OF_IMAGE} --weight_path ${PATH_OF_MODEL}
+```
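+
+The predictor can also be driven from Python. A minimal sketch, assuming the same released weights as above (the image path is a placeholder):
+
+```python
+from ppgan.apps import NAFNetPredictor
+
+# weight_path=None downloads the released denoising weights automatically
+predictor = NAFNetPredictor(output_path='output_dir',
+                            weight_path=None,
+                            seed=42)
+predictor.run(images_path='path/to/noisy_images')  # a file or a directory
+```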
+
+### 2.2 Prepare dataset
+
+The denoising model is trained on SIDD, an image denoising dataset containing 30,000 noisy images captured under 10 different lighting conditions. It can be downloaded from the [training dataset](https://www.eecs.yorku.ca/~kamel/sidd/dataset.php) and [test dataset](https://drive.google.com/drive/folders/1S44fHXaVxAYW3KLNxK41NYCnyX9S79su) links.
+After downloading, decompress it into the `data` directory. The decompressed `SIDD` dataset is structured as follows:
+
+```sh
+SIDD
+├── train
+│   ├── input
+│   └── target
+└── val
+    ├── input
+    └── target
+
+```
+Users can also use the [SIDD data](https://aistudio.baidu.com/aistudio/datasetdetail/149460) on AI Studio, but the folders `input_crops` and `gt_crops` must be renamed to `input` and `target` (a rename sketch is given in the zh_CN tutorial below).
+
+### 2.3 Training
+
+The example below trains a denoising model. To train another task, switch the dataset and modify the config file accordingly.
+
+```sh
+python -u tools/main.py --config-file configs/nafnet_denoising.yaml
+```
+
+### 2.4 Test
+
+To test the model:
+```sh
+python tools/main.py --config-file configs/nafnet_denoising.yaml --evaluate-only --load ${PATH_OF_WEIGHT}
+```
+
+## 3 Results
+Denoising
+| model | dataset | PSNR/SSIM |
+|---|---|---|
+| NAFNet | SIDD Val | 43.1468 / 0.9563 |
+
+## 4 Download
+
+| model | link |
+|---|---|
+| NAFNet| [NAFNet_Denoising](https://paddlegan.bj.bcebos.com/models/NAFNet_Denoising.pdparams) |
+
+# References
+
+- [Simple Baselines for Image Restoration](https://arxiv.org/pdf/2204.04676)
+
+```
+@article{chen_simple_nodate,
+    title = {Simple {Baselines} for {Image} {Restoration}},
+    abstract = {Although there have been significant advances in the field of image restoration recently, the system complexity of the state-of-the-art (SOTA) methods is increasing as well, which may hinder the convenient analysis and comparison of methods. In this paper, we propose a simple baseline that exceeds the SOTA methods and is computationally efficient. To further simplify the baseline, we reveal that the nonlinear activation functions, e.g. Sigmoid, ReLU, GELU, Softmax, etc. are not necessary: they could be replaced by multiplication or removed. Thus, we derive a Nonlinear Activation Free Network, namely NAFNet, from the baseline. SOTA results are achieved on various challenging benchmarks, e.g. 33.69 dB PSNR on GoPro (for image deblurring), exceeding the previous SOTA 0.38 dB with only 8.4\% of its computational costs; 40.30 dB PSNR on SIDD (for image denoising), exceeding the previous SOTA 0.28 dB with less than half of its computational costs. The code and the pretrained models will be released at github.com/megvii-research/NAFNet.},
+    language = {en},
+    author = {Chen, Liangyu and Chu, Xiaojie and Zhang, Xiangyu and Sun, Jian},
+    pages = {17}
+}
+```
diff --git a/docs/zh_CN/tutorials/nafnet.md b/docs/zh_CN/tutorials/nafnet.md
new file mode 100644
index 0000000..c193d2c
--- /dev/null
+++ b/docs/zh_CN/tutorials/nafnet.md
@@ -0,0 +1,87 @@
+[English](../../en_US/tutorials/nafnet.md) | Chinese
+
+# NAFNet: Simple Baselines for Image Restoration
+
+## 1 Introduction
+
+NAFNet proposes an ultra-simple baseline network, dubbed Baseline, that is both computationally efficient and stronger than the previous SOTA methods; simplifying this Baseline further, by removing its non-linear activation units, yields NAFNet and improves performance once more. The model sets new SOTA results on both SIDD denoising and GoPro deblurring while cutting the computational cost substantially. As shown in the figure below, the network uses a UNet with skip connections as the overall architecture, modifies the Transformer module of the Restormer block, removes the activation functions in favor of the simpler and more effective SimpleGate design, and applies a simplified channel attention mechanism.
+
+![NAFNet](https://ai-studio-static-online.cdn.bcebos.com/699b87449c7e495f8655ae5ac8bc0eb77bed4d9cd828451e8939ddbc5732a704)
+
+For a more detailed introduction to the model, please refer to the original paper [Simple Baselines for Image Restoration](https://arxiv.org/pdf/2204.04676). PaddleGAN currently provides the weights for the denoising task.
+
+## 2 How to use
+
+### 2.1 Quick start
+
+After installing PaddleGAN, enter the PaddleGAN directory and run the following command to generate the restored image `./output_dir/Denoising/image_name.png`:
+
+```sh
+python applications/tools/nafnet_denoising.py --images_path ${PATH_OF_IMAGE}
+```
+Here `PATH_OF_IMAGE` is the path of the image to denoise, or of a folder containing the images. To use your own model weights, run the following command, where `PATH_OF_MODEL` is the path of the weights:
+
+```sh
+python applications/tools/nafnet_denoising.py --images_path ${PATH_OF_IMAGE} --weight_path ${PATH_OF_MODEL}
+```
+
+### 2.2 Data preparation
+
+The denoising training data is SIDD, an image denoising dataset containing 30,000 noisy images captured under 10 different lighting conditions, available via the [training dataset](https://www.eecs.yorku.ca/~kamel/sidd/dataset.php) and [test dataset](https://drive.google.com/drive/folders/1S44fHXaVxAYW3KLNxK41NYCnyX9S79su) links.
+Decompress the downloads into the `data` directory; the layout after decompression is:
+
+```sh
+SIDD
+├── train
+│   ├── input
+│   └── target
+└── val
+    ├── input
+    └── target
+
+```
+Users can also use the [SIDD data](https://aistudio.baidu.com/aistudio/datasetdetail/149460) on AI Studio, but the folders `input_crops` and `gt_crops` must be renamed to `input` and `target`, for example:
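+
+```sh
+# Assumed layout of the AI Studio archive; adjust the paths to wherever it
+# was unpacked. The crop folders are renamed to what the dataloader expects.
+mv data/SIDD/train/input_crops data/SIDD/train/input
+mv data/SIDD/train/gt_crops data/SIDD/train/target
+```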
+
+### 2.3 Training
+
+The example below trains the denoising model. To train another task, switch the dataset and modify the config file accordingly.
+
+```sh
+python -u tools/main.py --config-file configs/nafnet_denoising.yaml
+```
+
+### 2.4 Testing
+
+To test the model:
+```sh
+python tools/main.py --config-file configs/nafnet_denoising.yaml --evaluate-only --load ${PATH_OF_WEIGHT}
+```
+
+## 3 Results
+
+Denoising
+| model | dataset | PSNR/SSIM |
+|---|---|---|
+| NAFNet | SIDD Val | 43.1468 / 0.9563 |
+
+## 4 Model download
+
+| model | download link |
+|---|---|
+| NAFNet| [NAFNet_Denoising](https://paddlegan.bj.bcebos.com/models/NAFNet_Denoising.pdparams) |
+
+# References
+
+- [Simple Baselines for Image Restoration](https://arxiv.org/pdf/2204.04676)
+
+```
+@article{chen_simple_nodate,
+    title = {Simple {Baselines} for {Image} {Restoration}},
+    abstract = {Although there have been significant advances in the field of image restoration recently, the system complexity of the state-of-the-art (SOTA) methods is increasing as well, which may hinder the convenient analysis and comparison of methods. In this paper, we propose a simple baseline that exceeds the SOTA methods and is computationally efficient. To further simplify the baseline, we reveal that the nonlinear activation functions, e.g. Sigmoid, ReLU, GELU, Softmax, etc. are not necessary: they could be replaced by multiplication or removed. Thus, we derive a Nonlinear Activation Free Network, namely NAFNet, from the baseline. SOTA results are achieved on various challenging benchmarks, e.g. 33.69 dB PSNR on GoPro (for image deblurring), exceeding the previous SOTA 0.38 dB with only 8.4\% of its computational costs; 40.30 dB PSNR on SIDD (for image denoising), exceeding the previous SOTA 0.28 dB with less than half of its computational costs. The code and the pretrained models will be released at github.com/megvii-research/NAFNet.},
+    language = {en},
+    author = {Chen, Liangyu and Chu, Xiaojie and Zhang, Xiangyu and Sun, Jian},
+    pages = {17}
+}
+```
diff --git a/ppgan/apps/__init__.py b/ppgan/apps/__init__.py
index ad49bea..37bce43 100644
--- a/ppgan/apps/__init__.py
+++ b/ppgan/apps/__init__.py
@@ -39,4 +39,5 @@ from .singan_predictor import SinGANPredictor
 from .gpen_predictor import GPENPredictor
 from .swinir_predictor import SwinIRPredictor
 from .invdn_predictor import InvDNPredictor
+from .nafnet_predictor import NAFNetPredictor
 from .aotgan_predictor import AOTGANPredictor
diff --git a/ppgan/apps/nafnet_predictor.py b/ppgan/apps/nafnet_predictor.py
new file mode 100644
index 0000000..b7f7fef
--- /dev/null
+++ b/ppgan/apps/nafnet_predictor.py
@@ -0,0 +1,155 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import cv2
+from glob import glob
+from natsort import natsorted
+import numpy as np
+import os
+import random
+from tqdm import tqdm
+
+import paddle
+
+from ppgan.models.generators import NAFNet
+from ppgan.utils.download import get_path_from_url
+from .base_predictor import BasePredictor
+
+model_cfgs = {
+    'Denoising': {
+        'model_urls':
+        'https://paddlegan.bj.bcebos.com/models/NAFNet_Denoising.pdparams',
+        'img_channel': 3,
+        'width': 64,
+        'enc_blk_nums': [2, 2, 4, 8],
+        'middle_blk_num': 12,
+        'dec_blk_nums': [2, 2, 2, 2]
+    }
+}
+
+
+class NAFNetPredictor(BasePredictor):
+
+    def __init__(self,
+                 output_path='output_dir',
+                 weight_path=None,
+                 seed=None,
+                 window_size=8):
+        self.output_path = output_path
+        task = 'Denoising'
+        self.task = task
+        self.window_size = window_size
+
+        if weight_path is None:
+            if task in model_cfgs.keys():
+                weight_path = get_path_from_url(model_cfgs[task]['model_urls'])
+                checkpoint = paddle.load(weight_path)
+            else:
+                raise ValueError('Predictor needs a task to be defined!')
+        else:
+            if weight_path.startswith("http"):  # os.path.islink doesn't work!
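+                # remote checkpoints are downloaded to the local cache first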
+                weight_path = get_path_from_url(weight_path)
+                checkpoint = paddle.load(weight_path)
+            else:
+                checkpoint = paddle.load(weight_path)
+
+        self.generator = NAFNet(
+            img_channel=model_cfgs[task]['img_channel'],
+            width=model_cfgs[task]['width'],
+            enc_blk_nums=model_cfgs[task]['enc_blk_nums'],
+            middle_blk_num=model_cfgs[task]['middle_blk_num'],
+            dec_blk_nums=model_cfgs[task]['dec_blk_nums'])
+
+        checkpoint = checkpoint['generator']
+        self.generator.set_state_dict(checkpoint)
+        self.generator.eval()
+
+        if seed is not None:
+            paddle.seed(seed)
+            random.seed(seed)
+            np.random.seed(seed)
+
+    def get_images(self, images_path):
+        if os.path.isdir(images_path):
+            return natsorted(
+                glob(os.path.join(images_path, '*.jpeg')) +
+                glob(os.path.join(images_path, '*.jpg')) +
+                glob(os.path.join(images_path, '*.JPG')) +
+                glob(os.path.join(images_path, '*.png')) +
+                glob(os.path.join(images_path, '*.PNG')))
+        else:
+            return [images_path]
+
+    def imread_uint(self, path, n_channels=3):
+        # input: path
+        # output: HxWx3 (RGB or GGG), or HxWx1 (G)
+        if n_channels == 1:
+            img = cv2.imread(path, 0)  # cv2.IMREAD_GRAYSCALE
+            img = np.expand_dims(img, axis=2)  # HxWx1
+        elif n_channels == 3:
+            img = cv2.imread(path, cv2.IMREAD_UNCHANGED)  # BGR or G
+            if img.ndim == 2:
+                img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)  # GGG
+            else:
+                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # RGB
+
+        return img
+
+    def uint2single(self, img):
+        return np.float32(img / 255.)
+
+    # convert single (HxWxC) to 3-dimensional paddle tensor
+    def single2tensor3(self, img):
+        return paddle.Tensor(np.ascontiguousarray(
+            img, dtype=np.float32)).transpose([2, 0, 1])
+
+    def run(self, images_path=None):
+        os.makedirs(self.output_path, exist_ok=True)
+        task_path = os.path.join(self.output_path, self.task)
+        os.makedirs(task_path, exist_ok=True)
+        image_files = self.get_images(images_path)
+        for image_file in tqdm(image_files):
+            img_L = self.imread_uint(image_file, 3)
+
+            image_name = os.path.basename(image_file)
+            # keep a copy of the noisy input next to the restored result
+            img = cv2.cvtColor(img_L, cv2.COLOR_RGB2BGR)
+            cv2.imwrite(os.path.join(task_path, image_name), img)
+
+            tmps = image_name.split('.')
+            assert len(
+                tmps) == 2, f'Invalid image name: {image_name}, too many "."'
+            restoration_save_path = os.path.join(
+                task_path, f'{tmps[0]}_restoration.{tmps[1]}')
+
+            img_L = self.uint2single(img_L)
+
+            # HWC to CHW, numpy to tensor
+            img_L = self.single2tensor3(img_L)
+            img_L = img_L.unsqueeze(0)
+            with paddle.no_grad():
+                output = self.generator(img_L)
+
+            restored = paddle.clip(output, 0, 1)
+
+            restored = restored.numpy()
+            restored = restored.transpose(0, 2, 3, 1)
+            restored = restored[0]
+            restored = restored * 255
+            restored = restored.astype(np.uint8)
+
+            cv2.imwrite(restoration_save_path,
+                        cv2.cvtColor(restored, cv2.COLOR_RGB2BGR))
+
+        print('Done, output path is:', task_path)
diff --git a/ppgan/datasets/__init__.py b/ppgan/datasets/__init__.py
old mode 100755
new mode 100644
index e660d8d..e1527de
--- a/ppgan/datasets/__init__.py
+++ b/ppgan/datasets/__init__.py
@@ -35,4 +35,5 @@ from .swinir_dataset import SwinIRDataset
 from .gfpgan_datasets import FFHQDegradationDataset
 from .paired_image_datasets import PairedImageDataset
 from .invdn_dataset import InvDNDataset
+from .nafnet_dataset import NAFNetTrain, NAFNetVal, NAFNetTest
 from .aotgan_dataset import AOTGANDataset
diff --git a/ppgan/datasets/nafnet_dataset.py b/ppgan/datasets/nafnet_dataset.py
new file mode 100644
index 0000000..88cab70
--- /dev/null
+++ b/ppgan/datasets/nafnet_dataset.py
@@ -0,0 +1,193 @@
+# Code was heavily based on https://github.com/swz30/MPRNet
+# Users should be careful about adopting these functions in any commercial matters.
+# https://github.com/swz30/MPRNet/blob/main/LICENSE.md
+
+import os
+import random
+import numpy as np
+from PIL import Image
+
+from paddle.io import Dataset
+from .builder import DATASETS
+from paddle.vision.transforms.functional import to_tensor, adjust_brightness, \
+    adjust_saturation, rotate, hflip, vflip, center_crop, pad
+
+
+def is_image_file(filename):
+    return any(
+        filename.endswith(extension)
+        for extension in ['jpeg', 'JPEG', 'jpg', 'png', 'JPG', 'PNG', 'gif'])
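+
+# Usage sketch (assumed paths): the classes below are registered in DATASETS
+# and are normally instantiated from the YAML `dataset` section, but they can
+# also be used standalone:
+#
+#     from ppgan.datasets import NAFNetTrain
+#     dataset = NAFNetTrain(rgb_dir='data/SIDD/train',
+#                           img_options={'patch_size': 256})
+#     tar_img, inp_img, stem = dataset[0]  # target patch, input patch, name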
+
+
+@DATASETS.register()
+class NAFNetTrain(Dataset):
+
+    def __init__(self, rgb_dir, img_options=None):
+        super(NAFNetTrain, self).__init__()
+
+        inp_files = sorted(os.listdir(os.path.join(rgb_dir, 'input')))
+        tar_files = sorted(os.listdir(os.path.join(rgb_dir, 'target')))
+
+        self.inp_filenames = [
+            os.path.join(rgb_dir, 'input', x) for x in inp_files
+            if is_image_file(x)
+        ]
+        self.tar_filenames = [
+            os.path.join(rgb_dir, 'target', x) for x in tar_files
+            if is_image_file(x)
+        ]
+
+        self.img_options = img_options
+        self.sizex = len(self.tar_filenames)  # get the size of target
+
+        self.ps = self.img_options['patch_size']
+
+    def __len__(self):
+        return self.sizex
+
+    def __getitem__(self, index):
+        index_ = index % self.sizex
+        ps = self.ps
+
+        inp_path = self.inp_filenames[index_]
+        tar_path = self.tar_filenames[index_]
+
+        inp_img = Image.open(inp_path)
+        tar_img = Image.open(tar_path)
+
+        w, h = tar_img.size
+        padw = ps - w if w < ps else 0
+        padh = ps - h if h < ps else 0
+
+        # Reflect-pad (via paddle.vision's functional pad, which accepts PIL
+        # images and a reflect mode) in case the image is smaller than
+        # patch_size
+        if padw != 0 or padh != 0:
+            inp_img = pad(inp_img, (0, 0, padw, padh),
+                          padding_mode='reflect')
+            tar_img = pad(tar_img, (0, 0, padw, padh),
+                          padding_mode='reflect')
+
+        aug = random.randint(0, 2)
+        if aug == 1:
+            inp_img = adjust_brightness(inp_img, 1)
+            tar_img = adjust_brightness(tar_img, 1)
+
+        aug = random.randint(0, 2)
+        if aug == 1:
+            sat_factor = 1 + (0.2 - 0.4 * np.random.rand())
+            inp_img = adjust_saturation(inp_img, sat_factor)
+            tar_img = adjust_saturation(tar_img, sat_factor)
+
+        # Data augmentations: random flip/rotation combinations, applied
+        # identically to the input and the target
+        aug = random.randint(0, 8)
+        if aug == 1:
+            inp_img = vflip(inp_img)
+            tar_img = vflip(tar_img)
+        elif aug == 2:
+            inp_img = hflip(inp_img)
+            tar_img = hflip(tar_img)
+        elif aug == 3:
+            inp_img = rotate(inp_img, 90)
+            tar_img = rotate(tar_img, 90)
+        elif aug == 4:
+            inp_img = rotate(inp_img, 90 * 2)
+            tar_img = rotate(tar_img, 90 * 2)
+        elif aug == 5:
+            inp_img = rotate(inp_img, 90 * 3)
+            tar_img = rotate(tar_img, 90 * 3)
+        elif aug == 6:
+            inp_img = rotate(vflip(inp_img), 90)
+            tar_img = rotate(vflip(tar_img), 90)
+        elif aug == 7:
+            inp_img = rotate(hflip(inp_img), 90)
+            tar_img = rotate(hflip(tar_img), 90)
+
+        inp_img = to_tensor(inp_img)
+        tar_img = to_tensor(tar_img)
+
+        hh, ww = tar_img.shape[1], tar_img.shape[2]
+
+        rr = random.randint(0, hh - ps)
+        cc = random.randint(0, ww - ps)
+
+        # Crop patch
+        inp_img = inp_img[:, rr:rr + ps, cc:cc + ps]
+        tar_img = tar_img[:, rr:rr + ps, cc:cc + ps]
+
+        filename = os.path.splitext(os.path.split(tar_path)[-1])[0]
+
+        return tar_img, inp_img, filename
+
+
+@DATASETS.register()
+class NAFNetVal(Dataset):
+
+    def __init__(self, rgb_dir, img_options=None, rgb_dir2=None):
+        super(NAFNetVal, self).__init__()
+
+        inp_files = sorted(os.listdir(os.path.join(rgb_dir, 'input')))
+        tar_files = sorted(os.listdir(os.path.join(rgb_dir, 'target')))
+
+        self.inp_filenames = [
+            os.path.join(rgb_dir, 'input',
x) for x in inp_files + if is_image_file(x) + ] + self.tar_filenames = [ + os.path.join(rgb_dir, 'target', x) for x in tar_files + if is_image_file(x) + ] + + self.img_options = img_options + self.sizex = len(self.tar_filenames) # get the size of target + + self.ps = self.img_options['patch_size'] + + def __len__(self): + return self.sizex + + def __getitem__(self, index): + index_ = index % self.sizex + ps = self.ps + + inp_path = self.inp_filenames[index_] + tar_path = self.tar_filenames[index_] + + inp_img = Image.open(inp_path) + tar_img = Image.open(tar_path) + + # Validate on center crop + if self.ps is not None: + inp_img = center_crop(inp_img, (ps, ps)) + tar_img = center_crop(tar_img, (ps, ps)) + + inp_img = to_tensor(inp_img) + tar_img = to_tensor(tar_img) + + filename = os.path.splitext(os.path.split(tar_path)[-1])[0] + + return tar_img, inp_img, filename + + +@DATASETS.register() +class NAFNetTest(Dataset): + + def __init__(self, inp_dir, img_options): + super(NAFNetTest, self).__init__() + + inp_files = sorted(os.listdir(inp_dir)) + self.inp_filenames = [ + os.path.join(inp_dir, x) for x in inp_files if is_image_file(x) + ] + + self.inp_size = len(self.inp_filenames) + self.img_options = img_options + + def __len__(self): + return self.inp_size + + def __getitem__(self, index): + + path_inp = self.inp_filenames[index] + filename = os.path.splitext(os.path.split(path_inp)[-1])[0] + inp = Image.open(path_inp) + + inp = to_tensor(inp) + return inp, filename diff --git a/ppgan/models/__init__.py b/ppgan/models/__init__.py index ffeb466..7229c61 100644 --- a/ppgan/models/__init__.py +++ b/ppgan/models/__init__.py @@ -41,4 +41,5 @@ from .gpen_model import GPENModel from .swinir_model import SwinIRModel from .gfpgan_model import GFPGANModel from .invdn_model import InvDNModel +from .nafnet_model import NAFNetModel from .aotgan_model import AOTGANModel diff --git a/ppgan/models/criterions/__init__.py b/ppgan/models/criterions/__init__.py index 219b72e..6d5bcd4 100644 --- a/ppgan/models/criterions/__init__.py +++ b/ppgan/models/criterions/__init__.py @@ -2,7 +2,7 @@ from .gan_loss import GANLoss from .perceptual_loss import PerceptualLoss from .pixel_loss import L1Loss, MSELoss, CharbonnierLoss, \ CalcStyleEmdLoss, CalcContentReltLoss, \ - CalcContentLoss, CalcStyleLoss, EdgeLoss + CalcContentLoss, CalcStyleLoss, EdgeLoss, PSNRLoss from .photopen_perceptual_loss import PhotoPenPerceptualLoss from .gradient_penalty import GradientPenalty diff --git a/ppgan/models/criterions/pixel_loss.py b/ppgan/models/criterions/pixel_loss.py index ca60f55..1e10374 100644 --- a/ppgan/models/criterions/pixel_loss.py +++ b/ppgan/models/criterions/pixel_loss.py @@ -31,6 +31,7 @@ class L1Loss(): loss_weight (float): Loss weight for L1 loss. Default: 1.0. """ + def __init__(self, reduction='mean', loss_weight=1.0): # when loss weight less than zero return None if loss_weight <= 0: @@ -59,6 +60,7 @@ class CharbonnierLoss(): eps (float): Default: 1e-12. """ + def __init__(self, eps=1e-12, reduction='sum'): self.eps = eps self.reduction = reduction @@ -90,6 +92,7 @@ class MSELoss(): loss_weight (float): Loss weight for MSE loss. Default: 1.0. """ + def __init__(self, reduction='mean', loss_weight=1.0): # when loss weight less than zero return None if loss_weight <= 0: @@ -119,6 +122,7 @@ class BCEWithLogitsLoss(): Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'. loss_weight (float): Loss weight for MSE loss. Default: 1.0. 
""" + def __init__(self, reduction='mean', loss_weight=1.0): # when loss weight less than zero return None if loss_weight <= 0: @@ -161,6 +165,7 @@ def calc_emd_loss(pred, target): class CalcStyleEmdLoss(): """Calc Style Emd Loss. """ + def __init__(self): super(CalcStyleEmdLoss, self).__init__() @@ -183,6 +188,7 @@ class CalcStyleEmdLoss(): class CalcContentReltLoss(): """Calc Content Relt Loss. """ + def __init__(self): super(CalcContentReltLoss, self).__init__() @@ -207,6 +213,7 @@ class CalcContentReltLoss(): class CalcContentLoss(): """Calc Content Loss. """ + def __init__(self): self.mse_loss = nn.MSELoss() @@ -229,6 +236,7 @@ class CalcContentLoss(): class CalcStyleLoss(): """Calc Style Loss. """ + def __init__(self): self.mse_loss = nn.MSELoss() @@ -247,6 +255,7 @@ class CalcStyleLoss(): @CRITERIONS.register() class EdgeLoss(): + def __init__(self): k = paddle.to_tensor([[.05, .25, .4, .25, .05]]) self.kernel = paddle.matmul(k.t(), k).unsqueeze(0).tile([3, 1, 1, 1]) @@ -271,3 +280,27 @@ class EdgeLoss(): y.stop_gradient = True loss = self.loss(self.laplacian_kernel(x), self.laplacian_kernel(y)) return loss + + +@CRITERIONS.register() +class PSNRLoss(nn.Layer): + + def __init__(self, loss_weight=1.0, reduction='mean', toY=False): + super(PSNRLoss, self).__init__() + assert reduction == 'mean' + self.loss_weight = loss_weight + self.scale = 10 / np.log(10) + self.toY = toY + self.coef = paddle.to_tensor(np.array([65.481, 128.553, + 24.966])).reshape([1, 3, 1, 1]) + + def forward(self, pred, target): + if self.toY: + pred = (pred * self.coef).sum(axis=1).unsqueeze(axis=1) + 16. + target = (target * self.coef).sum(axis=1).unsqueeze(axis=1) + 16. + + pred, target = pred / 255., target / 255. + pass + + return self.loss_weight * self.scale * paddle.log(( + (pred - target)**2).mean(axis=[1, 2, 3]) + 1e-8).mean() diff --git a/ppgan/models/generators/__init__.py b/ppgan/models/generators/__init__.py old mode 100755 new mode 100644 index ba5a730..7b00707 --- a/ppgan/models/generators/__init__.py +++ b/ppgan/models/generators/__init__.py @@ -46,4 +46,5 @@ from .swinir import SwinIR from .gfpganv1_clean_arch import GFPGANv1Clean from .gfpganv1_arch import GFPGANv1, StyleGAN2DiscriminatorGFPGAN from .invdn import InvDN +from .nafnet import NAFNet, NAFNetLocal from .generater_aotgan import InpaintGenerator diff --git a/ppgan/models/generators/nafnet.py b/ppgan/models/generators/nafnet.py new file mode 100644 index 0000000..fa17bcd --- /dev/null +++ b/ppgan/models/generators/nafnet.py @@ -0,0 +1,407 @@ +# Copyright (c) 2022 megvii-model. All Rights Reserved. +# Modified from BasicSR (https://github.com/xinntao/BasicSR) +# Copyright 2018-2020 BasicSR Authors + +import paddle +from paddle import nn as nn +import paddle.nn.functional as F +from paddle.autograd import PyLayer + +from .builder import GENERATORS + + +class LayerNormFunction(PyLayer): + + @staticmethod + def forward(ctx, x, weight, bias, eps): + ctx.eps = eps + N, C, H, W = x.shape + mu = x.mean(1, keepdim=True) + var = (x - mu).pow(2).mean(1, keepdim=True) + y = (x - mu) / (var + eps).sqrt() + ctx.save_for_backward(y, var, weight) + y = weight.reshape([1, C, 1, 1]) * y + bias.reshape([1, C, 1, 1]) + return y + + @staticmethod + def backward(ctx, grad_output): + eps = ctx.eps + + N, C, H, W = grad_output.shape + y, var, weight = ctx.saved_tensor() + g = grad_output * weight.reshape([1, C, 1, 1]) + mean_g = g.mean(axis=1, keepdim=True) + + mean_gy = (g * y).mean(axis=1, keepdim=True) + gx = 1. 
diff --git a/ppgan/models/generators/__init__.py b/ppgan/models/generators/__init__.py
old mode 100755
new mode 100644
index ba5a730..7b00707
--- a/ppgan/models/generators/__init__.py
+++ b/ppgan/models/generators/__init__.py
@@ -46,4 +46,5 @@ from .swinir import SwinIR
 from .gfpganv1_clean_arch import GFPGANv1Clean
 from .gfpganv1_arch import GFPGANv1, StyleGAN2DiscriminatorGFPGAN
 from .invdn import InvDN
+from .nafnet import NAFNet, NAFNetLocal
 from .generater_aotgan import InpaintGenerator
diff --git a/ppgan/models/generators/nafnet.py b/ppgan/models/generators/nafnet.py
new file mode 100644
index 0000000..fa17bcd
--- /dev/null
+++ b/ppgan/models/generators/nafnet.py
@@ -0,0 +1,407 @@
+# Copyright (c) 2022 megvii-model. All Rights Reserved.
+# Modified from BasicSR (https://github.com/xinntao/BasicSR)
+# Copyright 2018-2020 BasicSR Authors
+
+import paddle
+from paddle import nn as nn
+import paddle.nn.functional as F
+from paddle.autograd import PyLayer
+
+from .builder import GENERATORS
+
+
+class LayerNormFunction(PyLayer):
+    """Channel-wise LayerNorm over NCHW feature maps, with a custom backward."""
+
+    @staticmethod
+    def forward(ctx, x, weight, bias, eps):
+        ctx.eps = eps
+        N, C, H, W = x.shape
+        mu = x.mean(1, keepdim=True)
+        var = (x - mu).pow(2).mean(1, keepdim=True)
+        y = (x - mu) / (var + eps).sqrt()
+        ctx.save_for_backward(y, var, weight)
+        y = weight.reshape([1, C, 1, 1]) * y + bias.reshape([1, C, 1, 1])
+        return y
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        eps = ctx.eps
+
+        N, C, H, W = grad_output.shape
+        y, var, weight = ctx.saved_tensor()
+        g = grad_output * weight.reshape([1, C, 1, 1])
+        mean_g = g.mean(axis=1, keepdim=True)
+
+        mean_gy = (g * y).mean(axis=1, keepdim=True)
+        gx = 1. / paddle.sqrt(var + eps) * (g - y * mean_gy - mean_g)
+        return gx, (grad_output * y).sum(axis=3).sum(axis=2).sum(
+            axis=0), grad_output.sum(axis=3).sum(axis=2).sum(axis=0)
+
+
+class LayerNorm2D(nn.Layer):
+
+    def __init__(self, channels, eps=1e-6):
+        super(LayerNorm2D, self).__init__()
+        self.add_parameter(
+            'weight',
+            self.create_parameter(
+                [channels],
+                default_initializer=paddle.nn.initializer.Constant(value=1.0)))
+        self.add_parameter(
+            'bias',
+            self.create_parameter(
+                [channels],
+                default_initializer=paddle.nn.initializer.Constant(value=0.0)))
+        self.eps = eps
+
+    def forward(self, x):
+        if self.training:
+            y = LayerNormFunction.apply(x, self.weight, self.bias, self.eps)
+        else:
+            N, C, H, W = x.shape
+            mu = x.mean(1, keepdim=True)
+            var = (x - mu).pow(2).mean(1, keepdim=True)
+            y = (x - mu) / (var + self.eps).sqrt()
+            y = self.weight.reshape([1, C, 1, 1]) * y + self.bias.reshape(
+                [1, C, 1, 1])
+
+        return y
+
+
+class AvgPool2D(nn.Layer):
+    """Local average pooling that stands in for global AdaptiveAvgPool2D at
+    test time, so statistics are aggregated over windows matching the
+    training crop size (TLSC, see the reference below)."""
+
+    def __init__(self,
+                 kernel_size=None,
+                 base_size=None,
+                 auto_pad=True,
+                 fast_imp=False,
+                 train_size=None):
+        super().__init__()
+        self.kernel_size = kernel_size
+        self.base_size = base_size
+        self.auto_pad = auto_pad
+
+        # only used for fast implementation
+        self.fast_imp = fast_imp
+        self.rs = [5, 4, 3, 2, 1]
+        self.max_r1 = self.rs[0]
+        self.max_r2 = self.rs[0]
+        self.train_size = train_size
+
+    def extra_repr(self) -> str:
+        return 'kernel_size={}, base_size={}, stride={}, fast_imp={}'.format(
+            self.kernel_size, self.base_size, self.kernel_size, self.fast_imp)
+
+    def forward(self, x):
+        if self.kernel_size is None and self.base_size:
+            # derive the pooling window from the ratio between the current
+            # input size and the training size
+            train_size = self.train_size
+            if isinstance(self.base_size, int):
+                self.base_size = (self.base_size, self.base_size)
+            self.kernel_size = list(self.base_size)
+            self.kernel_size[
+                0] = x.shape[2] * self.base_size[0] // train_size[-2]
+            self.kernel_size[
+                1] = x.shape[3] * self.base_size[1] // train_size[-1]
+
+            # only used for fast implementation
+            self.max_r1 = max(1, self.rs[0] * x.shape[2] // train_size[-2])
+            self.max_r2 = max(1, self.rs[0] * x.shape[3] // train_size[-1])
+
+        if self.kernel_size[0] >= x.shape[-2] and self.kernel_size[
+                1] >= x.shape[-1]:
+            return F.adaptive_avg_pool2d(x, 1)
+
+        if self.fast_imp:  # non-equivalent implementation but faster
+            h, w = x.shape[2:]
+            if self.kernel_size[0] >= h and self.kernel_size[1] >= w:
+                out = F.adaptive_avg_pool2d(x, 1)
+            else:
+                r1 = [r for r in self.rs if h % r == 0][0]
+                r2 = [r for r in self.rs if w % r == 0][0]
+                # reduction constraint
+                r1 = min(self.max_r1, r1)
+                r2 = min(self.max_r2, r2)
+                # strided integral image, then box filtering
+                s = x[:, :, ::r1, ::r2].cumsum(axis=-1).cumsum(axis=-2)
+                n, c, h, w = s.shape
+                k1, k2 = min(h - 1, self.kernel_size[0] // r1), min(
+                    w - 1, self.kernel_size[1] // r2)
+                out = (s[:, :, :-k1, :-k2] - s[:, :, :-k1, k2:] -
+                       s[:, :, k1:, :-k2] + s[:, :, k1:, k2:]) / (k1 * k2)
+                out = paddle.nn.functional.interpolate(out,
+                                                       scale_factor=(r1, r2))
+        else:
+            n, c, h, w = x.shape
+            # integral image: each window sum becomes four corner lookups
+            s = x.cumsum(axis=-1).cumsum(axis=-2)
+            s = paddle.nn.functional.pad(s,
+                                         [1, 0, 1, 0])  # pad 0 for convenience
+            k1, k2 = min(h, self.kernel_size[0]), min(w, self.kernel_size[1])
+            s1 = s[:, :, :-k1, :-k2]
+            s2 = s[:, :, :-k1, k2:]
+            s3 = s[:, :, k1:, :-k2]
+            s4 = s[:, :, k1:, k2:]
+            out = s4 + s1 - s2 - s3
+            out = out / (k1 * k2)
+
+        if self.auto_pad:
+            n, c, h, w = x.shape
+            _h, _w = out.shape[2:]
+            pad2d = [(w - _w) // 2, (w - _w + 1) // 2, (h - _h) // 2,
+                     (h - _h + 1) // 2]
+            out = paddle.nn.functional.pad(out, pad2d, mode='replicate')
+
+        return out
+
+
+def
replace_layers(model, base_size, train_size, fast_imp, **kwargs): + for n, m in model.named_children(): + if len(list(m.children())) > 0: + ## compound module, go inside it + replace_layers(m, base_size, train_size, fast_imp, **kwargs) + + if isinstance(m, nn.AdaptiveAvgPool2D): + pool = AvgPool2D(base_size=base_size, + fast_imp=fast_imp, + train_size=train_size) + assert m._output_size == 1 + setattr(model, n, pool) + + +''' +ref. +@article{chu2021tlsc, + title={Revisiting Global Statistics Aggregation for Improving Image Restoration}, + author={Chu, Xiaojie and Chen, Liangyu and and Chen, Chengpeng and Lu, Xin}, + journal={arXiv preprint arXiv:2112.04491}, + year={2021} +} +''' + + +class Local_Base(): + + def convert(self, *args, train_size, **kwargs): + replace_layers(self, *args, train_size=train_size, **kwargs) + imgs = paddle.rand(train_size) + with paddle.no_grad(): + self.forward(imgs) + + +class SimpleGate(nn.Layer): + + def forward(self, x): + x1, x2 = x.chunk(2, axis=1) + return x1 * x2 + + +class NAFBlock(nn.Layer): + + def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.): + super().__init__() + dw_channel = c * DW_Expand + self.conv1 = nn.Conv2D(in_channels=c, + out_channels=dw_channel, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias_attr=True) + self.conv2 = nn.Conv2D(in_channels=dw_channel, + out_channels=dw_channel, + kernel_size=3, + padding=1, + stride=1, + groups=dw_channel, + bias_attr=True) + self.conv3 = nn.Conv2D(in_channels=dw_channel // 2, + out_channels=c, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias_attr=True) + + # Simplified Channel Attention + self.sca = nn.Sequential( + nn.AdaptiveAvgPool2D(1), + nn.Conv2D(in_channels=dw_channel // 2, + out_channels=dw_channel // 2, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias_attr=True), + ) + + # SimpleGate + self.sg = SimpleGate() + + ffn_channel = FFN_Expand * c + self.conv4 = nn.Conv2D(in_channels=c, + out_channels=ffn_channel, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias_attr=True) + self.conv5 = nn.Conv2D(in_channels=ffn_channel // 2, + out_channels=c, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias_attr=True) + + self.norm1 = LayerNorm2D(c) + self.norm2 = LayerNorm2D(c) + + self.drop_out_rate = drop_out_rate + + self.dropout1 = nn.Dropout( + drop_out_rate) if drop_out_rate > 0. else None + self.dropout2 = nn.Dropout( + drop_out_rate) if drop_out_rate > 0. 
else None + + self.add_parameter( + "beta", + self.create_parameter( + [1, c, 1, 1], + default_initializer=paddle.nn.initializer.Constant(value=0.0))) + self.add_parameter( + "gamma", + self.create_parameter( + [1, c, 1, 1], + default_initializer=paddle.nn.initializer.Constant(value=0.0))) + + def forward(self, inp): + x = inp + + x = self.norm1(x) + + x = self.conv1(x) + x = self.conv2(x) + x = self.sg(x) + x = x * self.sca(x) + x = self.conv3(x) + + if self.drop_out_rate > 0: + x = self.dropout1(x) + + y = inp + x * self.beta + + x = self.conv4(self.norm2(y)) + x = self.sg(x) + x = self.conv5(x) + + if self.drop_out_rate > 0: + x = self.dropout2(x) + + return y + x * self.gamma + + +@GENERATORS.register() +class NAFNet(nn.Layer): + + def __init__(self, + img_channel=3, + width=16, + middle_blk_num=1, + enc_blk_nums=[], + dec_blk_nums=[]): + super().__init__() + + self.intro = nn.Conv2D(in_channels=img_channel, + out_channels=width, + kernel_size=3, + padding=1, + stride=1, + groups=1, + bias_attr=True) + self.ending = nn.Conv2D(in_channels=width, + out_channels=img_channel, + kernel_size=3, + padding=1, + stride=1, + groups=1, + bias_attr=True) + + self.encoders = nn.LayerList() + self.decoders = nn.LayerList() + self.middle_blks = nn.LayerList() + self.ups = nn.LayerList() + self.downs = nn.LayerList() + + chan = width + for num in enc_blk_nums: + self.encoders.append( + nn.Sequential(*[NAFBlock(chan) for _ in range(num)])) + self.downs.append(nn.Conv2D(chan, 2 * chan, 2, 2)) + chan = chan * 2 + + self.middle_blks = \ + nn.Sequential( + *[NAFBlock(chan) for _ in range(middle_blk_num)] + ) + + for num in dec_blk_nums: + self.ups.append( + nn.Sequential(nn.Conv2D(chan, chan * 2, 1, bias_attr=False), + nn.PixelShuffle(2))) + chan = chan // 2 + self.decoders.append( + nn.Sequential(*[NAFBlock(chan) for _ in range(num)])) + + self.padder_size = 2**len(self.encoders) + + def forward(self, inp): + B, C, H, W = inp.shape + inp = self.check_image_size(inp) + + x = self.intro(inp) + + encs = [] + + for encoder, down in zip(self.encoders, self.downs): + x = encoder(x) + encs.append(x) + x = down(x) + + x = self.middle_blks(x) + + for decoder, up, enc_skip in zip(self.decoders, self.ups, encs[::-1]): + x = up(x) + x = x + enc_skip + x = decoder(x) + + x = self.ending(x) + x = x + inp + + return x[:, :, :H, :W] + + def check_image_size(self, x): + _, _, h, w = x.shape + mod_pad_h = (self.padder_size - h % self.padder_size) % self.padder_size + mod_pad_w = (self.padder_size - w % self.padder_size) % self.padder_size + x = F.pad(x, [0, mod_pad_w, 0, mod_pad_h]) + return x + + +@GENERATORS.register() +class NAFNetLocal(Local_Base, NAFNet): + + def __init__(self, + *args, + train_size=(1, 3, 256, 256), + fast_imp=False, + **kwargs): + Local_Base.__init__(self) + NAFNet.__init__(self, *args, **kwargs) + + N, C, H, W = train_size + base_size = (int(H * 1.5), int(W * 1.5)) + + self.eval() + with paddle.no_grad(): + self.convert(base_size=base_size, + train_size=train_size, + fast_imp=fast_imp) diff --git a/ppgan/models/nafnet_model.py b/ppgan/models/nafnet_model.py new file mode 100644 index 0000000..d4c7cf1 --- /dev/null +++ b/ppgan/models/nafnet_model.py @@ -0,0 +1,104 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+import paddle
+import paddle.nn as nn
+
+from .builder import MODELS
+from .base_model import BaseModel
+from .generators.builder import build_generator
+from .criterions.builder import build_criterion
+from ..utils.visual import tensor2img
+
+
+@MODELS.register()
+class NAFNetModel(BaseModel):
+    """NAFNet Model.
+
+    Paper: Simple Baselines for Image Restoration
+    https://arxiv.org/pdf/2204.04676
+    """
+
+    def __init__(self, generator, psnr_criterion=None):
+        """Initialize the NAFNetModel class.
+
+        Args:
+            generator (dict): config of generator.
+            psnr_criterion (dict): config of psnr criterion.
+        """
+        super(NAFNetModel, self).__init__(generator)
+        self.current_iter = 1
+
+        self.nets['generator'] = build_generator(generator)
+
+        if psnr_criterion:
+            self.psnr_criterion = build_criterion(psnr_criterion)
+
+    def setup_input(self, input):
+        # the datasets yield (target, input, filename); see nafnet_dataset.py
+        self.target = input[0]
+        self.lq = input[1]
+
+    def train_iter(self, optims=None):
+        optims['optim'].clear_gradients()
+
+        restored = self.nets['generator'](self.lq)
+
+        loss = self.psnr_criterion(restored, self.target)
+
+        loss.backward()
+        optims['optim'].step()
+        self.losses['loss'] = loss.numpy()
+
+    def forward(self):
+        pass
+
+    def test_iter(self, metrics=None):
+        self.nets['generator'].eval()
+        with paddle.no_grad():
+            self.output = self.nets['generator'](self.lq)
+            self.visual_items['output'] = self.output
+        self.nets['generator'].train()
+
+        out_img = []
+        gt_img = []
+        for out_tensor, gt_tensor in zip(self.output, self.target):
+            out_img.append(tensor2img(out_tensor, (0., 1.)))
+            gt_img.append(tensor2img(gt_tensor, (0., 1.)))
+
+        if metrics is not None:
+            for metric in metrics.values():
+                metric.update(out_img, gt_img)
+
+    def export_model(self,
+                     export_model=None,
+                     output_dir=None,
+                     inputs_size=None,
+                     export_serving_model=False,
+                     model_name=None):
+        shape = inputs_size[0]
+        new_model = self.nets['generator']
+        new_model.eval()
+        input_spec = [paddle.static.InputSpec(shape=shape, dtype="float32")]
+
+        static_model = paddle.jit.to_static(new_model, input_spec=input_spec)
+
+        if output_dir is None:
+            output_dir = 'inference_model'
+        if model_name is None:
+            model_name = '{}_{}'.format(self.__class__.__name__.lower(),
+                                        export_model[0]['name'])
+
+        paddle.jit.save(static_model, os.path.join(output_dir, model_name))
diff --git a/ppgan/solver/optimizer.py b/ppgan/solver/optimizer.py
index 36345b5..bc9c48d 100644
--- a/ppgan/solver/optimizer.py
+++ b/ppgan/solver/optimizer.py
@@ -21,3 +21,4 @@ OPTIMIZERS.register(paddle.optimizer.Adam)
 OPTIMIZERS.register(paddle.optimizer.SGD)
 OPTIMIZERS.register(paddle.optimizer.Momentum)
 OPTIMIZERS.register(paddle.optimizer.RMSProp)
+OPTIMIZERS.register(paddle.optimizer.AdamW)
diff --git a/test_tipc/configs/nafnet/train_infer_python.txt b/test_tipc/configs/nafnet/train_infer_python.txt
new file mode 100644
index 0000000..8cd3ddc
--- /dev/null
+++ b/test_tipc/configs/nafnet/train_infer_python.txt
@@ -0,0 +1,51 @@
+===========================train_params===========================
+model_name:nafnet
+python:python3.7
+gpu_list:0
+##
+auto_cast:null
+total_iters:lite_train_lite_infer=10
+output_dir:./output/ +snapshot_config.interval:lite_train_lite_infer=10 +pretrained_model:null +train_model_name:nafnet*/*checkpoint.pdparams +train_infer_img_dir:null +null:null +## +trainer:norm_train +norm_train:tools/main.py -c configs/nafnet_denoising.yaml --seed 100 -o log_config.interval=1 +pact_train:null +fpgm_train:null +distill_train:null +null:null +null:null +## +===========================eval_params=========================== +eval:null +null:null +## +===========================infer_params=========================== +--output_dir:./output/ +load:null +norm_export:tools/export_model.py -c configs/nafnet_denoising.yaml --inputs_size=1,3,256,256 --model_name inference --load +quant_export:null +fpgm_export:null +distill_export:null +export1:null +export2:null +inference_dir:inference +train_model:./inference/nafnet/nafnetmodel_generator +infer_export:null +infer_quant:False +inference:tools/inference.py --model_type nafnet --seed 100 -c configs/nafnet_denoising.yaml --output_path test_tipc/output/ +--device:gpu +null:null +null:null +null:null +null:null +null:null +--model_path: +null:null +null:null +--benchmark:True +null:null diff --git a/test_tipc/prepare.sh b/test_tipc/prepare.sh index b29b2ef..c172b96 100644 --- a/test_tipc/prepare.sh +++ b/test_tipc/prepare.sh @@ -70,6 +70,13 @@ if [ ${MODE} = "lite_train_lite_infer" ];then rm -rf ./data/SIDD_* wget -nc -P ./data/ https://paddlegan.bj.bcebos.com/datasets/SIDD_mini.zip --no-check-certificate cd ./data/ && unzip -q SIDD_mini.zip && cd ../ ;; + nafnet) + rm -rf ./data/SIDD* + wget -nc -P ./data/ https://paddlegan.bj.bcebos.com/datasets/SIDD_mini.zip --no-check-certificate + cd ./data/ && unzip -q SIDD_mini.zip && mkdir -p SIDD && mv ./SIDD_Medium_Srgb_Patches_512/* ./SIDD/ \ + && mv ./SIDD_Valid_Srgb_Patches_256/* ./SIDD/ && mv ./SIDD/valid ./SIDD/val \ + && mv ./SIDD/train/GT ./SIDD/train/target && mv ./SIDD/train/Noisy ./SIDD/train/input \ + && mv ./SIDD/val/Noisy ./SIDD/val/input && mv ./SIDD/val/GT ./SIDD/val/target && cd ../ ;; singan) rm -rf ./data/singan* wget -nc -P ./data/ https://paddlegan.bj.bcebos.com/datasets/singan-official_images.zip --no-check-certificate diff --git a/tools/inference.py b/tools/inference.py index 536b7b3..11ea4f4 100644 --- a/tools/inference.py +++ b/tools/inference.py @@ -19,7 +19,8 @@ from ppgan.metrics import build_metric MODEL_CLASSES = ["pix2pix", "cyclegan", "wav2lip", "esrgan", \ - "edvr", "fom", "stylegan2", "basicvsr", "msvsr", "singan", "swinir", "invdn", "aotgan"] + "edvr", "fom", "stylegan2", "basicvsr", "msvsr", \ + "singan", "swinir", "invdn", "aotgan", "nafnet"] def parse_args(): @@ -423,8 +424,31 @@ def main(): for metric in metrics.values(): metric.update(image_numpy, gt_numpy) break + + elif model_type == "nafnet": + lq = data[1].numpy() + input_handles[0].copy_from_cpu(lq) + predictor.run() + prediction = output_handle.copy_to_cpu() + prediction = paddle.to_tensor(prediction) + target = tensor2img(data[0], (0., 1.)) + prediction = tensor2img(prediction, (0., 1.)) + + metric_file = os.path.join(args.output_path, model_type, + "metric.txt") + for metric in metrics.values(): + metric.update(prediction, target) + + lq = tensor2img(data[1], (0., 1.)) + + sample_result = np.concatenate((lq, prediction, target), 1) + sample = cv2.cvtColor(sample_result, cv2.COLOR_RGB2BGR) + file_name = os.path.join(args.output_path, model_type, + "{}.png".format(i)) + cv2.imwrite(file_name, sample) elif model_type == 'aotgan': - input_data = paddle.concat((data['img'], data['mask']), 
axis=1).numpy() + input_data = paddle.concat((data['img'], data['mask']), + axis=1).numpy() input_handles[0].copy_from_cpu(input_data) predictor.run() prediction = output_handle.copy_to_cpu() -- GitLab