未验证 提交 2b28df6e 编写于 作者: W wangna11BD 提交者: GitHub

add mpr predictor and doc (#530)

* add mpr predictor and doc

* add doc

* fix result

* fix Results

* fix copyright
上级 49ef722c
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import paddle
import os
import sys
sys.path.insert(0, os.getcwd())
from ppgan.apps import AnimeGANPredictor
import argparse
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import argparse
import paddle
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import paddle
import os
import sys
sys.path.insert(0, os.getcwd())
from ppgan.apps import MPRPredictor
import argparse
......@@ -22,18 +37,18 @@ if __name__ == "__main__":
default=None,
help="sample random seed for model's image generation")
parser.add_argument('--images_path',
default=None,
parser.add_argument('--images_path',
default=None,
required=True,
type=str,
type=str,
help='Single image or images directory.')
parser.add_argument('--task',
required=True,
type=str,
help='Task to run',
parser.add_argument('--task',
required=True,
type=str,
help='Task to run',
choices=['Deblurring', 'Denoising', 'Deraining'])
parser.add_argument("--cpu",
dest="cpu",
action="store_true",
......@@ -44,11 +59,9 @@ if __name__ == "__main__":
if args.cpu:
paddle.set_device('cpu')
predictor = MPRPredictor(
images_path=args.images_path,
output_path=args.output_path,
weight_path=args.weight_path,
seed=args.seed,
task=args.task
)
predictor = MPRPredictor(images_path=args.images_path,
output_path=args.output_path,
weight_path=args.weight_path,
seed=args.seed,
task=args.task)
predictor.run()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import paddle
import os
import sys
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import paddle
import os
import sys
sys.path.insert(0, os.getcwd())
from ppgan.apps import MPRPredictor
import argparse
if __name__ == "__main__":
    # Command-line entry point: build an MPRPredictor from the parsed
    # arguments and run it on the image given by --input_image.
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_image", type=str, help="path to image")
    parser.add_argument("--output_path",
                        type=str,
                        default='output_dir',
                        help="path to output image dir")
    parser.add_argument("--weight_path",
                        type=str,
                        default=None,
                        help="path to model weight path")
    # Restrict --task to the three values named in the help text so that a
    # misspelled task is rejected by argparse at parse time.
    parser.add_argument(
        "--task",
        type=str,
        default='Deblurring',
        choices=['Deblurring', 'Denoising', 'Deraining'],
        help="task can be chosen in 'Deblurring', 'Denoising', 'Deraining'")
    parser.add_argument("--cpu",
                        dest="cpu",
                        action="store_true",
                        help="cpu mode.")
    args = parser.parse_args()

    # Switch paddle to the CPU device only when --cpu is passed; otherwise
    # paddle's current device setting is left unchanged.
    if args.cpu:
        paddle.set_device('cpu')

    predictor = MPRPredictor(output_path=args.output_path,
                             task=args.task,
                             weight_path=args.weight_path)
    # The input image path is passed to run() rather than to the constructor.
    predictor.run(args.input_image)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import paddle
import os
import sys
......@@ -39,10 +53,9 @@ if __name__ == "__main__":
if args.cpu:
paddle.set_device('cpu')
cfg = get_config(args.config_file)
predictor = PhotoPenPredictor(output_path=args.output_path,
weight_path=args.weight_path,
gen_cfg=cfg.predict)
weight_path=args.weight_path,
gen_cfg=cfg.predict)
predictor.run(semantic_label_path=args.semantic_label_path)
\ No newline at end of file
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import argparse
import paddle
......@@ -103,11 +117,10 @@ parser.add_argument(
type=str,
default='sfd',
help="face detector to be used, can choose s3fd or blazeface")
parser.add_argument(
"--face_enhancement",
dest="face_enhancement",
action="store_true",
help="use face enhance for face")
parser.add_argument("--face_enhancement",
dest="face_enhancement",
action="store_true",
help="use face enhance for face")
parser.set_defaults(face_enhancement=False)
if __name__ == "__main__":
......@@ -115,17 +128,17 @@ if __name__ == "__main__":
if args.cpu:
paddle.set_device('cpu')
predictor = Wav2LipPredictor(checkpoint_path = args.checkpoint_path,
static = args.static,
fps = args.fps,
pads = args.pads,
face_det_batch_size = args.face_det_batch_size,
wav2lip_batch_size = args.wav2lip_batch_size,
resize_factor = args.resize_factor,
crop = args.crop,
box = args.box,
rotate = args.rotate,
nosmooth = args.nosmooth,
face_detector = args.face_detector,
face_enhancement = args.face_enhancement)
predictor = Wav2LipPredictor(checkpoint_path=args.checkpoint_path,
static=args.static,
fps=args.fps,
pads=args.pads,
face_det_batch_size=args.face_det_batch_size,
wav2lip_batch_size=args.wav2lip_batch_size,
resize_factor=args.resize_factor,
crop=args.crop,
box=args.box,
rotate=args.rotate,
nosmooth=args.nosmooth,
face_detector=args.face_detector,
face_enhancement=args.face_enhancement)
predictor.run(args.face, args.audio, args.outfile)
......@@ -6,6 +6,9 @@ model:
name: MPRModel
generator:
name: MPRNet
n_feat: 96
scale_unetfeats: 48
scale_orsnetfeats: 32
char_criterion:
name: CharbonnierLoss
......@@ -15,16 +18,16 @@ model:
dataset:
train:
name: MPRTrain
rgb_dir: 'data/GoPro/train'
rgb_dir: data/GoPro/train
num_workers: 4
batch_size: 2
batch_size: 2 # 8GPUs
img_options:
patch_size: 256
test:
name: MPRVal
rgb_dir: 'data/GoPro/test'
num_workers: 4
batch_size: 2
rgb_dir: data/GoPro/test
num_workers: 1
batch_size: 1
img_options:
patch_size: 256
......
total_iters: 100000
output_dir: output_dir
model:
name: MPRModel
generator:
name: MPRNet
n_feat: 80
scale_unetfeats: 48
scale_orsnetfeats: 32
char_criterion:
name: CharbonnierLoss
edge_criterion:
name: EdgeLoss
dataset:
train:
name: MPRTrain
rgb_dir: data/SIDD/train
num_workers: 16
batch_size: 4 # 4GPUs
img_options:
patch_size: 256
test:
name: MPRTrain
rgb_dir: data/SIDD/val
num_workers: 1
batch_size: 1
img_options:
patch_size: 256
lr_scheduler:
name: CosineAnnealingRestartLR
learning_rate: !!float 2e-4
periods: [25000, 25000, 25000, 25000]
restart_weights: [1, 1, 1, 1]
eta_min: !!float 1e-6
validate:
interval: 5000
save_img: false
metrics:
psnr: # metric name, can be arbitrary
name: PSNR
crop_border: 4
test_y_channel: True
ssim:
name: SSIM
crop_border: 4
test_y_channel: True
optimizer:
name: Adam
# add parameters of net_name to optim
# name should in self.nets
net_names:
- generator
beta1: 0.9
beta2: 0.999
epsilon: 1e-8
log_config:
interval: 10
visiual_interval: 5000
snapshot_config:
interval: 5000
total_iters: 100000
output_dir: output_dir
model:
name: MPRModel
generator:
name: MPRNet
n_feat: 40
scale_unetfeats: 20
scale_orsnetfeats: 16
char_criterion:
name: CharbonnierLoss
edge_criterion:
name: EdgeLoss
dataset:
train:
name: MPRTrain
rgb_dir: data/Synthetic_Rain_Datasets/train
num_workers: 16
batch_size: 4 # 4GPUs
img_options:
patch_size: 256
test:
name: MPRTrain
rgb_dir: data/Synthetic_Rain_Datasets/test/Rain100L
num_workers: 1
batch_size: 1
img_options:
patch_size: 256
lr_scheduler:
name: CosineAnnealingRestartLR
learning_rate: !!float 2e-4
periods: [25000, 25000, 25000, 25000]
restart_weights: [1, 1, 1, 1]
eta_min: !!float 1e-6
validate:
interval: 5000
save_img: false
metrics:
psnr: # metric name, can be arbitrary
name: PSNR
crop_border: 4
test_y_channel: True
ssim:
name: SSIM
crop_border: 4
test_y_channel: True
optimizer:
name: Adam
# add parameters of net_name to optim
# name should in self.nets
net_names:
- generator
beta1: 0.9
beta2: 0.999
epsilon: 1e-8
log_config:
interval: 10
visiual_interval: 5000
snapshot_config:
interval: 5000
# MPR_Net
## 1 Introduction
[MPR_Net](https://arxiv.org/abs/2102.02808) is an image restoration method published in CVPR2021. Image restoration tasks demand a complex balance between spatial details and high-level contextualized information while recovering images. MPR_Net proposes a novel synergistic design that can optimally balance these competing goals. The main proposal is a multi-stage architecture that progressively learns restoration functions for the degraded inputs, thereby breaking down the overall recovery process into more manageable steps. Specifically, the model first learns the contextualized features using encoder-decoder architectures and later combines them with a high-resolution branch that retains local information. At each stage, MPR_Net introduces a novel per-pixel adaptive design that leverages in-situ supervised attention to reweight the local features. A key ingredient in such a multi-stage architecture is the information exchange between different stages. To this end, MPR_Net proposes a two-faceted approach where the information is not only exchanged sequentially from early to late stages, but lateral connections between feature processing blocks also exist to avoid any loss of information. The resulting tightly interlinked multi-stage architecture, named MPRNet, delivers strong performance gains on ten datasets across a range of tasks including image deraining, deblurring, and denoising.
## 2 How to use
### 2.1 Quick start
After installing PaddleGAN, you can run the following Python code to generate the restored image. Here `task` is the type of restoration to perform, which can be chosen from `Deblurring`, `Denoising` and `Deraining`, and `PATH_OF_IMAGE` is the path to your image.
```python
from ppgan.apps import MPRPredictor
predictor = MPRPredictor(task='Deblurring')
predictor.run(PATH_OF_IMAGE)
```
Or run such a command to get the same result:
```sh
python applications/tools/mprnet.py --input_image ${PATH_OF_IMAGE} --task Deblurring
```
Here `task` is the type of restoration to perform, which can be chosen from `Deblurring`, `Denoising` and `Deraining`, and `PATH_OF_IMAGE` is the path to your image.
### 2.1 Prepare dataset
The training dataset for deblurring is GoPro. The GoPro dataset used for deblurring consists of 3214 blurred images with a size of 1,280×720. These images are divided into 2103 training images and 1111 test images. It can be downloaded from [here](https://drive.google.com/file/d/1H0PIXvJH4c40pk7ou6nAwoxuR4Qh_Sa2/view?usp=sharing).
After downloading, decompress it into the `data` directory. After decompression, the structure of the `GoPro` dataset is as follows:
```sh
GoPro
├── train
│ ├── input
│ └── target
└── test
├── input
└── target
```
The training dataset for denoising is SIDD, an image denoising dataset containing 30,000 noisy images from 10 different lighting conditions. It can be downloaded from the [training dataset](https://www.eecs.yorku.ca/~kamel/sidd/dataset.php) and [test dataset](https://drive.google.com/drive/folders/1S44fHXaVxAYW3KLNxK41NYCnyX9S79su) links.
After downloading, decompress it into the `data` directory. After decompression, the structure of the `SIDD` dataset is as follows:
```sh
SIDD
├── train
│ ├── input
│ └── target
└── val
├── input
└── target
```
The training dataset for deraining is Synthetic Rain Datasets, which consists of 13,712 clean-rain image pairs collected from multiple datasets (Rain14000, Rain1800, Rain800, Rain12). It can be downloaded from the [training dataset](https://drive.google.com/drive/folders/1Hnnlc5kI0v9_BtfMytC2LR5VpLAFZtVe) and [test dataset](https://drive.google.com/drive/folders/1PDWggNh8ylevFmrjo-JEvlmqsDlWWvZs) links.
After downloading, decompress it into the `data` directory. After decompression, the structure of `Synthetic_Rain_Datasets` is as follows:
```sh
Synthetic_Rain_Datasets
├── train
│ ├── input
│ └── target
└── test
├── Test100
├── Rain100H
├── Rain100L
├── Test1200
└── Test2800
```
### 2.2 Training
The following example trains the deblurring model. To train for another task, replace the config file with the one for that task.
```sh
python -u tools/main.py --config-file configs/mprnet_deblurring.yaml
```
### 2.3 Test
Test the model:
```sh
python tools/main.py --config-file configs/mprnet_deblurring.yaml --evaluate-only --load ${PATH_OF_WEIGHT}
```
## 3 Results
Deblurring
| model | dataset | PSNR/SSIM |
|---|---|---|
| MPRNet | GoPro | 33.4360/0.9410 |
Denoising
| model | dataset | PSNR/SSIM |
|---|---|---|
| MPRNet | SIDD | 43.6100 / 0.9586 |
Deraining
| model | dataset | PSNR/SSIM |
|---|---|---|
| MPRNet | Rain100L | 36.2848 / 0.9651 |
## 4 Download
| model | link |
|---|---|
| MPR_Deblurring | [MPR_Deblurring](https://paddlegan.bj.bcebos.com/models/MPR_Deblurring.pdparams) |
| MPR_Denoising | [MPR_Denoising](https://paddlegan.bj.bcebos.com/models/MPR_Denoising.pdparams) |
| MPR_Deraining | [MPR_Deraining](https://paddlegan.bj.bcebos.com/models/MPR_Deraining.pdparams) |
# References
- [Multi-Stage Progressive Image Restoration](https://arxiv.org/abs/2102.02808)
```
@inproceedings{Zamir2021MPRNet,
title={Multi-Stage Progressive Image Restoration},
author={Syed Waqas Zamir and Aditya Arora and Salman Khan and Munawar Hayat and Fahad Shahbaz Khan and Ming-Hsuan Yang and Ling Shao},
booktitle={CVPR},
year={2021}
}
```
......@@ -130,11 +130,6 @@ The metrics are PSNR / SSIM.
| pan_x4 | 30.4574 / 0.8643 | 26.7204 / 0.7434 | 28.9187 / 0.8176 |
| drns_x4 | 32.6684 / 0.8999 | 28.9037 / 0.7885 | - |
Deblur models zoo
| model | GoPro | Download Link |
|---|---|---|
| MPRNet | 33.4360 / 0.9410 | [link](https://paddlegan.bj.bcebos.com/models/MPR_Deblurring.pdparams) |
<!-- ![](../../imgs/horse2zebra.png) -->
......
# MPR_Net
## 1 原理介绍
[MPR_Net](https://arxiv.org/abs/2102.02808)是发表在CVPR2021的一篇图像修复方法。图像修复任务需要在恢复图像时在空间细节和高级上下文信息之间实现复杂的平衡。MPR_Net提出了一种新颖的协同设计,可以最佳地平衡这些相互竞争的目标。其中主要提议是一个多阶段架构,它逐步学习退化输入的恢复函数,从而将整个恢复过程分解为更易于管理的步骤。具体来说,MPR_Net首先使用编码器-解码器架构学习上下文特征,然后将它们与保留本地信息的高分辨率分支相结合。在每个阶段引入了一种新颖的每像素自适应设计,利用原位监督注意力来重新加权局部特征。这种多阶段架构的一个关键要素是不同阶段之间的信息交换。为此,MPR_Net提出了一种双向方法,其中信息不仅从早期到后期按顺序交换,而且特征处理块之间也存在横向连接以避免任何信息丢失。由此产生的紧密互连的多阶段架构,称为MPRNet,在包括图像去雨、去模糊和去噪在内的一系列任务中,在十个数据集上提供了强大的性能提升。
## 2 如何使用
### 2.1 快速体验
安装`PaddleGAN`之后运行如下代码即生成修复后的图像`output_dir/Deblurring/image_name.png`,其中`task`为你想要修复的任务,可以在`Deblurring`、`Denoising`、`Deraining`中选择,`PATH_OF_IMAGE`为你需要转换的图像路径。
```python
from ppgan.apps import MPRPredictor
predictor = MPRPredictor(task='Deblurring')
predictor.run(PATH_OF_IMAGE)
```
或者在终端中运行如下命令,也可获得相同结果:
```sh
python applications/tools/mprnet.py --input_image ${PATH_OF_IMAGE} --task Deblurring
```
其中`task`为你想要修复的任务,可以在`Deblurring`、`Denoising`、`Deraining`中选择,`PATH_OF_IMAGE`为你需要转换的图像路径。
### 2.1 数据准备
Deblurring训练数据是GoPro,用于去模糊的GoPro数据集由3214张1,280×720大小的模糊图像组成,这些图像分为2103张训练图像和1111张测试图像。可以从[这里](https://drive.google.com/file/d/1H0PIXvJH4c40pk7ou6nAwoxuR4Qh_Sa2/view?usp=sharing)下载。
下载后解压到data目录下,解压完成后数据分布如下所示:
```sh
GoPro
├── train
│ ├── input
│ └── target
└── test
├── input
└── target
```
Denoising训练数据是SIDD,一个图像去噪数据集,包含来自10个不同光照条件下的3万幅噪声图像,可以分别从[训练数据集](https://www.eecs.yorku.ca/~kamel/sidd/dataset.php)和[测试数据集](https://drive.google.com/drive/folders/1S44fHXaVxAYW3KLNxK41NYCnyX9S79su)下载。
下载后解压到data目录下,解压完成后数据分布如下所示:
```sh
SIDD
├── train
│ ├── input
│ └── target
└── val
├── input
└── target
```
Deraining训练数据是Synthetic Rain Datasets,由13712张从多个数据集(Rain14000, Rain1800, Rain800, Rain12)收集的干净雨图像对组成,可以分别从[训练数据集](https://drive.google.com/drive/folders/1Hnnlc5kI0v9_BtfMytC2LR5VpLAFZtVe)和[测试数据集](https://drive.google.com/drive/folders/1PDWggNh8ylevFmrjo-JEvlmqsDlWWvZs)下载。
下载后解压到data目录下,解压完成后数据分布如下所示:
```sh
Synthetic_Rain_Datasets
├── train
│ ├── input
│ └── target
└── test
├── Test100
├── Rain100H
├── Rain100L
├── Test1200
└── Test2800
```
### 2.2 训练
示例以训练Deblurring的数据为例。如果想训练其他任务可以通过替换配置文件。
```sh
python -u tools/main.py --config-file configs/mprnet_deblurring.yaml
```
### 2.3 测试
测试模型:
```sh
python tools/main.py --config-file configs/mprnet_deblurring.yaml --evaluate-only --load ${PATH_OF_WEIGHT}
```
## 3 结果展示
去模糊
| 模型 | 数据集 | PSNR/SSIM |
|---|---|---|
| MPRNet | GoPro | 33.4360/0.9410 |
去噪
| 模型 | 数据集 | PSNR/SSIM |
|---|---|---|
| MPRNet | SIDD | 43.6100 / 0.9586 |
去雨
| 模型 | 数据集 | PSNR/SSIM |
|---|---|---|
| MPRNet | Rain100L | 36.2848 / 0.9651 |
## 4 模型下载
| 模型 | 下载地址 |
|---|---|
| MPR_Deblurring | [MPR_Deblurring](https://paddlegan.bj.bcebos.com/models/MPR_Deblurring.pdparams) |
| MPR_Denoising | [MPR_Denoising](https://paddlegan.bj.bcebos.com/models/MPR_Denoising.pdparams) |
| MPR_Deraining | [MPR_Deraining](https://paddlegan.bj.bcebos.com/models/MPR_Deraining.pdparams) |
# 参考文献
- [Multi-Stage Progressive Image Restoration](https://arxiv.org/abs/2102.02808)
```
@inproceedings{Zamir2021MPRNet,
title={Multi-Stage Progressive Image Restoration},
author={Syed Waqas Zamir and Aditya Arora and Salman Khan and Munawar Hayat and Fahad Shahbaz Khan and Ming-Hsuan Yang and Ling Shao},
booktitle={CVPR},
year={2021}
}
```
......@@ -120,11 +120,6 @@ paddle模型使用DIV2K数据集训练,torch模型使用df2k和DIV2K数据集
| paddle | 30.4574 / 0.8643 | 26.7204 / 0.7434 |
| torch | 30.2183 / 0.8643 | 26.8035 / 0.7445 |
去模糊模型
| 模型 | GoPro | 下载地址 |
|---|---|---|
| MPRNet | 33.4360 / 0.9410 | [链接](https://paddlegan.bj.bcebos.com/models/MPR_Deblurring.pdparams) |
<!-- ![](../../imgs/horse2zebra.png) -->
......
......@@ -55,13 +55,11 @@ model_cfgs = {
class MPRPredictor(BasePredictor):
def __init__(self,
images_path=None,
output_path='output_dir',
weight_path=None,
seed=None,
task=None):
self.output_path = output_path
self.images_path = images_path
self.task = task
self.max_size = 640
self.img_multiple_of = 8
......@@ -108,7 +106,7 @@ class MPRPredictor(BasePredictor):
img = img.resize((dw, dh))
return img
def run(self):
def run(self, images_path=None):
os.makedirs(self.output_path, exist_ok=True)
task_path = os.path.join(self.output_path, self.task)
os.makedirs(task_path, exist_ok=True)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册