未验证 提交 8aafe376 编写于 作者: W wangna11BD 提交者: GitHub

Add msvsr (#496)

* fix benchmark dataset

* fix edvr and basic bug

* add PP-MSVSR

* add experiment results in docs

* fix ssim

* modif

* modif2
上级 6a31877f
......@@ -39,7 +39,7 @@ run_cmd="set -xe;
nvidia-docker run --name test_paddlegan -i \
--net=host \
--shm-size=1g \
--shm-size=128g \
-v $PWD:/workspace \
${ImageName} /bin/bash -c "${run_cmd}"
```
......
......@@ -59,7 +59,7 @@ dataset:
test:
name: VSRFolderDataset
# for udm10 dataset
# for UDM10 dataset
# lq_folder: data/udm10/BDx4
# gt_folder: data/udm10/GT
lq_folder: data/Vid4/BDx4
......@@ -67,7 +67,7 @@ dataset:
preprocess:
- name: GetNeighboringFramesIdx
interval_list: [1]
# for udm10 dataset
# for UDM10 dataset
# filename_tmpl: '{:04d}.png'
filename_tmpl: '{:08d}.png'
- name: ReadImageSequence
......
......@@ -23,8 +23,8 @@ dataset:
train:
name: RepeatDataset
times: 1000
num_workers: 4 # 6
batch_size: 2 # 4*2
num_workers: 4
batch_size: 2 #4 GPUs
dataset:
name: SRREDSMultipleGTDataset
mode: train
......
......@@ -43,7 +43,7 @@ dataset:
scale: 4
fix_random_seed: 10
num_workers: 3
batch_size: 4 # 8GUPs
batch_size: 4 # 8GPUs
test:
......
......@@ -42,7 +42,7 @@ dataset:
scale: 4
fix_random_seed: 10
num_workers: 3
batch_size: 4 # 8GUPs
batch_size: 4 # 8GPUs
test:
......
......@@ -46,7 +46,7 @@ dataset:
scale: 4
fix_random_seed: 10
num_workers: 3
batch_size: 4 # 8GUPs
batch_size: 4 # 8GPUs
test:
......
......@@ -42,7 +42,7 @@ dataset:
scale: 4
fix_random_seed: 10
num_workers: 3
batch_size: 4 # 8GUPs
batch_size: 4 # 8GPUs
test:
......
......@@ -23,8 +23,8 @@ dataset:
train:
name: RepeatDataset
times: 1000
num_workers: 4 # 6
batch_size: 2 # 4*2
num_workers: 4
batch_size: 2 #4 GPUs
dataset:
name: SRREDSMultipleGTDataset
mode: train
......
......@@ -32,7 +32,7 @@ dataset:
load_size: 136
crop_size: 128
num_workers: 16
batch_size: 5
batch_size: 5 #1 GPUs
test:
name: LapStyleDataset
content_root: data/coco/test2017/
......
......@@ -38,7 +38,7 @@ dataset:
load_size: 280
crop_size: 256
num_workers: 16
batch_size: 5
batch_size: 5 #1 GPUs
test:
name: LapStyleDataset
content_root: data/coco/test2017/
......
......@@ -38,7 +38,7 @@ dataset:
load_size: 540
crop_size: 512
num_workers: 16
batch_size: 2
batch_size: 2 #1 GPUs
test:
name: LapStyleDataset
content_root: data/coco/test2017/
......
total_iters: 300000
output_dir: output_dir
find_unused_parameters: True
checkpoints_dir: checkpoints
use_dataset: True
# tensor range for function tensor2img
min_max:
(0., 1.)
model:
name: MultiStageVSRModel
fix_iter: 2500
generator:
name: MSVSR
mid_channels: 64
num_init_blocks: 5
num_blocks: 7
num_reconstruction_blocks: 5
only_last: False
use_tiny_spynet: False
deform_groups: 8
stage1_groups: 8
auxiliary_loss: True
use_refine_align: True
aux_reconstruction_blocks: 2
use_local_connnect: True
pixel_criterion:
name: CharbonnierLoss
reduction: mean
dataset:
train:
name: RepeatDataset
times: 1000
num_workers: 4
batch_size: 2 #8 gpus
use_shared_memory: True
dataset:
name: SRREDSMultipleGTDataset
mode: train
lq_folder: data/REDS/train_sharp_bicubic/X4
gt_folder: data/REDS/train_sharp/X4
crop_size: 256
interval_list: [1]
random_reverse: False
number_frames: 30
use_flip: True
use_rot: True
scale: 4
val_partition: REDS4
test:
name: SRREDSMultipleGTDataset
mode: test
lq_folder: data/REDS/REDS4_test_sharp_bicubic/X4
gt_folder: data/REDS/REDS4_test_sharp/X4
interval_list: [1]
random_reverse: False
number_frames: 100
use_flip: False
use_rot: False
scale: 4
val_partition: REDS4
num_workers: 0
batch_size: 1
lr_scheduler:
name: CosineAnnealingRestartLR
learning_rate: !!float 2e-4
periods: [300000]
restart_weights: [1]
eta_min: !!float 1e-7
optimizer:
name: Adam
# add parameters of net_name to optim
# name should in self.nets
net_names:
- generator
beta1: 0.9
beta2: 0.99
validate:
interval: 5000
save_img: false
metrics:
psnr: # metric name, can be arbitrary
name: PSNR
crop_border: 0
test_y_channel: false
ssim:
name: SSIM
crop_border: 0
test_y_channel: false
log_config:
interval: 10
visiual_interval: 5000
snapshot_config:
interval: 5000
total_iters: 150000
output_dir: output_dir
find_unused_parameters: True
checkpoints_dir: checkpoints
use_dataset: True
# tensor range for function tensor2img
min_max:
(0., 1.)
model:
name: MultiStageVSRModel
fix_iter: 2500
generator:
name: MSVSR
mid_channels: 32
num_init_blocks: 2
num_blocks: 3
num_reconstruction_blocks: 2
only_last: True
use_tiny_spynet: True
deform_groups: 4
stage1_groups: 8
auxiliary_loss: True
use_refine_align: True
aux_reconstruction_blocks: 1
use_local_connnect: True
pixel_criterion:
name: CharbonnierLoss
reduction: mean
dataset:
train:
name: RepeatDataset
times: 1000
num_workers: 6
batch_size: 2 #8 gpus
use_shared_memory: True
dataset:
name: SRREDSMultipleGTDataset
mode: train
lq_folder: data/REDS/train_sharp_bicubic/X4
gt_folder: data/REDS/train_sharp/X4
crop_size: 256
interval_list: [1]
random_reverse: False
number_frames: 20
use_flip: True
use_rot: True
scale: 4
val_partition: REDS4
test:
name: SRREDSMultipleGTDataset
mode: test
lq_folder: data/REDS/REDS4_test_sharp_bicubic/X4
gt_folder: data/REDS/REDS4_test_sharp/X4
interval_list: [1]
random_reverse: False
number_frames: 100
use_flip: False
use_rot: False
scale: 4
val_partition: REDS4
num_workers: 0
batch_size: 1
lr_scheduler:
name: CosineAnnealingRestartLR
learning_rate: !!float 2e-4
periods: [150000]
restart_weights: [1]
eta_min: !!float 1e-7
optimizer:
name: Adam
# add parameters of net_name to optim
# name should in self.nets
net_names:
- generator
beta1: 0.9
beta2: 0.99
validate:
interval: 5000
save_img: false
metrics:
psnr: # metric name, can be arbitrary
name: PSNR
crop_border: 0
test_y_channel: false
ssim:
name: SSIM
crop_border: 0
test_y_channel: false
log_config:
interval: 10
visiual_interval: 5000
snapshot_config:
interval: 5000
total_iters: 300000
output_dir: output_dir
find_unused_parameters: True
checkpoints_dir: checkpoints
use_dataset: True
# tensor range for function tensor2img
min_max:
(0., 1.)
model:
name: MultiStageVSRModel
fix_iter: -1
generator:
name: MSVSR
mid_channels: 32
num_init_blocks: 2
num_blocks: 3
num_reconstruction_blocks: 2
only_last: True
use_tiny_spynet: True
deform_groups: 4
stage1_groups: 8
auxiliary_loss: True
use_refine_align: True
aux_reconstruction_blocks: 1
use_local_connnect: True
pixel_criterion:
name: CharbonnierLoss
reduction: mean
dataset:
train:
name: RepeatDataset
times: 1000
num_workers: 4
batch_size: 2 #8 gpus
dataset:
name: VSRVimeo90KDataset
# mode: train
lq_folder: data/vimeo90k/vimeo_septuplet_BD_matlabLRx4/sequences
gt_folder: data/vimeo90k/vimeo_septuplet/sequences
ann_file: data/vimeo90k/vimeo_septuplet/sep_trainlist.txt
preprocess:
- name: ReadImageSequence
key: lq
- name: ReadImageSequence
key: gt
- name: Transforms
input_keys: [lq, gt]
pipeline:
- name: SRPairedRandomCrop
gt_patch_size: 256
scale: 4
keys: [image, image]
- name: PairedRandomHorizontalFlip
keys: [image, image]
- name: PairedRandomVerticalFlip
keys: [image, image]
- name: PairedRandomTransposeHW
keys: [image, image]
- name: TransposeSequence
keys: [image, image]
- name: MirrorVideoSequence
- name: NormalizeSequence
mean: [0., .0, 0.]
std: [255., 255., 255.]
keys: [image, image]
test:
name: VSRFolderDataset
# for udm10 dataset
# lq_folder: data/udm10/BDx4
# gt_folder: data/udm10/GT
lq_folder: data/Vid4/BDx4
gt_folder: data/Vid4/GT
preprocess:
- name: GetNeighboringFramesIdx
interval_list: [1]
# for udm10 dataset
# filename_tmpl: '{:04d}.png'
filename_tmpl: '{:08d}.png'
- name: ReadImageSequence
key: lq
- name: ReadImageSequence
key: gt
- name: Transforms
input_keys: [lq, gt]
pipeline:
- name: TransposeSequence
keys: [image, image]
- name: NormalizeSequence
mean: [0., .0, 0.]
std: [255., 255., 255.]
keys: [image, image]
lr_scheduler:
name: CosineAnnealingRestartLR
learning_rate: !!float 2e-4
periods: [300000]
restart_weights: [1]
eta_min: !!float 1e-7
optimizer:
name: Adam
# add parameters of net_name to optim
# name should in self.nets
net_names:
- generator
beta1: 0.9
beta2: 0.99
validate:
interval: 2500
save_img: false
metrics:
psnr: # metric name, can be arbitrary
name: PSNR
crop_border: 0
test_y_channel: true
ssim:
name: SSIM
crop_border: 0
test_y_channel: true
log_config:
interval: 10
visiual_interval: 5000
snapshot_config:
interval: 2500
......@@ -3,15 +3,22 @@
## 1.1 Principle
Video super-resolution originates from image super-resolution, which aims to recover high-resolution (HR) images from one or more low resolution (LR) images. The difference between them is that the video is composed of multiple frames, so the video super-resolution usually uses the information between frames to repair. Here we provide the video super-resolution model [EDVR](https://arxiv.org/pdf/1905.02716.pdf).[BasicVSR](https://arxiv.org/pdf/2012.02181.pdf),[IconVSR](https://arxiv.org/pdf/2012.02181.pdf),[BasicVSR++](https://arxiv.org/pdf/2104.13371v1.pdf).
Video super-resolution originates from image super-resolution, which aims to recover high-resolution (HR) images from one or more low resolution (LR) images. The difference between them is that the video is composed of multiple frames, so the video super-resolution usually uses the information between frames to repair. Here we provide the video super-resolution model [EDVR](https://arxiv.org/pdf/1905.02716.pdf), [BasicVSR](https://arxiv.org/pdf/2012.02181.pdf),[IconVSR](https://arxiv.org/pdf/2012.02181.pdf),[BasicVSR++](https://arxiv.org/pdf/2104.13371v1.pdf), and PP-MSVSR.
[EDVR](https://arxiv.org/pdf/1905.02716.pdf) wins the champions and outperforms the second place by a large margin in all four tracks in the NTIRE19 video restoration and enhancement challenges. The main difficulties of video super-resolution from two aspects: (1) how to align multiple frames given large motions, and (2) how to effectively fuse different frames with diverse motion and blur. First, to handle large motions, EDVR devise a Pyramid, Cascading and Deformable (PCD) alignment module, in which frame alignment is done at the feature level using deformable convolutions in a coarse-to-fine manner. Second, EDVR propose a Temporal and Spatial Attention (TSA) fusion module, in which attention is applied both temporally and spatially, so as to emphasize important features for subsequent restoration.
[BasicVSR](https://arxiv.org/pdf/2012.02181.pdf) reconsiders some most essential components for VSR guided by four basic functionalities, i.e., Propagation, Alignment, Aggregation, and Upsampling. By reusing some existing components added with minimal redesigns, a succinct pipeline, BasicVSR, achieves appealing improvements in terms of speed and restoration quality in comparison to many state-of-the-art algorithms. By presenting an informationrefill mechanism and a coupled propagation scheme to facilitate information aggregation, the BasicVSR can be expanded to [IconVSR](https://arxiv.org/pdf/2012.02181.pdf), which can serve as strong baselines for future VSR approaches.
[BasicVSR++](https://arxiv.org/pdf/2104.13371v1.pdf) redesign BasicVSR by proposing second-order grid propagation and flowguided deformable alignment. By empowering the recurrent framework with the enhanced propagation and alignment, BasicVSR++ can exploit spatiotemporal information across misaligned video frames more effectively. The new components lead to an improved performance under a similar computational constraint. In particular, BasicVSR++ surpasses BasicVSR by 0.82 dB in PSNR with similar number of parameters. In NTIRE 2021, BasicVSR++ obtains three champions and one runner-up in the Video Super-Resolution and Compressed Video Enhancement Challenges.
PP-MSVSR is a multi-stage VSR deep architecture, with local fusion module, auxiliary loss and refined align module to refine the enhanced result progressively. Specifically, in order to strengthen the fusion of features across frames in feature propagation, a local fusion module is designed in stage-1 to perform local feature fusion before feature propagation. Moreover, an auxiliary loss in stage-2 is introduced to make the features obtained by the propagation module reserve more correlated information connected to the HR space, and introduced a refined align module in stage-3 to make full use of the feature information of the previous stage. Extensive experiments substantiate that PP-MSVSR achieves a promising performance of Vid4 datasets, which PSNR metric can achieve 28.13 with only 1.45M parameters.
## 1.2 How to use
### 1.2.1 Prepare Datasets
Here are 4 commonly used video super-resolution dataset, REDS, Vimeo90K, Vid4, UDM10. The REDS and Vimeo90K dataset include train dataset and test dataset, Vid4 and UDM10 are test dataset. Download and decompress the required dataset and place it under the ``PaddleGAN/data``.
REDS([download](https://seungjunnah.github.io/Datasets/reds.html))is a newly proposed high-quality (720p) video dataset in the NTIRE19 Competition. REDS consists of 240 training clips, 30 validation clips and 30 testing clips (each with 100 consecutive frames). Since the test ground truth is not available, we select four representative clips (they are '000', '011', '015', '020', with diverse scenes and motions) as our test set, denoted by REDS4. The remaining training and validation clips are re-grouped as our training dataset (a total of 266 clips).
......@@ -31,6 +38,49 @@
...
```
Vimeo90K ([download](http://toflow.csail.mit.edu/)) is designed by Tianfan Xue etc. for the following four video processing tasks: temporal frame interpolation, video denoising, video deblocking, and video super-resolution. Vimeo90K is a large-scale, high-quality video dataset. This dataset consists of 89,800 video clips downloaded from vimeo.com, which covers large variaty of scenes and actions.
The structure of the processed Vimeo90K is as follows:
```
PaddleGAN
├── data
├── Vimeo90K
├── vimeo_septuplet
| |──sequences
| └──sep_trainlist.txt
├── vimeo_septuplet_BD_matlabLRx4
| └──sequences
└── vimeo_super_resolution_test
|──low_resolution
|──target
└──sep_testlist.txt
...
```
Vid4 ([Data Download](https://paddlegan.bj.bcebos.com/datasets/Vid4.zip)) is a commonly used test dataset for VSR, which contains 4 video segments.
The structure of the processed Vid4 is as follows:
```
PaddleGAN
├── data
├── Vid4
├── BDx4
└── GT
...
```
UDM10 ([Data Download](https://paddlegan.bj.bcebos.com/datasets/udm10_paddle.tar)) is a commonly used test dataset for VSR, which contains 10 video segments.
The structure of the processed UDM10 is as follows:
```
PaddleGAN
├── data
├── udm10
├── BDx4
└── GT
...
```
### 1.2.2 Train/Test
According to the number of channels, EDVR are divided into EDVR_L(128 channels) and EDVR_M (64 channels). Then, taking EDVR_M as an example, the model training and testing are introduced.
......@@ -63,24 +113,37 @@
python tools/main.py --config-file configs/edvr_m_w_tsa.yaml --evaluate-only --load ${PATH_OF_WEIGHT}
```
To train or test other VSR model, you can find the config file of the corresponding VSR model in the ``PaddleGAN/configs``, then change the config file in the command to the config file of corresponding VSR model.
## 1.3 Results
The experimental results are evaluated on RGB channel.
The metrics are PSNR / SSIM.
VSR quantitative comparis on the test dataset REDS4 from REDS dataset
| Method | Paramete(M) | FLOPs(G) | REDS4 |
|---|---|---|---|
| EDVR_M_wo_tsa_SRx4 | 3.00 | 223 | 30.4429 / 0.8684 |
| EDVR_M_w_tsa_SRx4 | 3.30 | 232 | 30.5169 / 0.8699 |
| EDVR_L_wo_tsa_SRx4 | 19.42 | 974 | 30.8649 / 0.8761 |
| EDVR_L_w_tsa_SRx4 | 20.63 | 1010 | 30.9336 / 0.8773 |
| BasicVSR_x4 | 6.29 | 374 | 31.4325 / 0.8913 |
| IconVSR_x4 | 8.69 | 516 | 31.6882 / 0.8950 |
| BasicVSR++_x4 | 7.32 | 406 | 32.4018 / 0.9071 |
| PP-MSVSR_reds_x4 | 1.45 | 111 | 31.2535 / 0.8884 |
| PP-MSVSR-L_reds_x4 | 7.42 | 543 | 32.5321 / 0.9083 |
Deblur quantitative comparis on the test dataset REDS4 from REDS dataset
| Method | REDS4 |
|---|---|
| EDVR_M_wo_tsa_SRx4 | 30.4429 / 0.8684 |
| EDVR_M_w_tsa_SRx4 | 30.5169 / 0.8699 |
| EDVR_L_wo_tsa_SRx4 | 30.8649 / 0.8761 |
| EDVR_L_w_tsa_SRx4 | 30.9336 / 0.8773 |
| EDVR_L_wo_tsa_deblur | 34.9587 / 0.9509 |
| EDVR_L_w_tsa_deblur | 35.1473 / 0.9526 |
| BasicVSR_x4 | 31.4325 / 0.8913 |
| IconVSR_x4 | 31.6882 / 0.8950 |
| BasicVSR++_x4 | 32.4018 / 0.9071 |
VSR quantitative comparis on the Vimeo90K, Vid4, UDM10
| Model | Vimeo90K | Vid4 | UDM10 |
|---|---|---|---|
| PP-MSVSR_vimeo90k_x4 |37.54/0.9499|28.13/0.8604|40.06/0.9699|
## 1.4 Model Download
| Method | Dataset | Download Link |
......@@ -94,7 +157,9 @@ The metrics are PSNR / SSIM.
| BasicVSR_x4 | REDS | [BasicVSR_x4](https://paddlegan.bj.bcebos.com/models/BasicVSR_reds_x4.pdparams)
| IconVSR_x4 | REDS | [IconVSR_x4](https://paddlegan.bj.bcebos.com/models/IconVSR_reds_x4.pdparams)
| BasicVSR++_x4 | REDS | [BasicVSR++_x4](https://paddlegan.bj.bcebos.com/models/BasicVSR%2B%2B_reds_x4.pdparams)
| PP-MSVSR_reds_x4 | REDS | [PP-MSVSR_reds_x4](https://paddlegan.bj.bcebos.com/models/PP-MSVSR_reds_x4.pdparams)
| PP-MSVSR-L_reds_x4 | REDS | [PP-MSVSR-L_reds_x4](https://paddlegan.bj.bcebos.com/models/PP-MSVSR-L_reds_x4.pdparams)
| PP-MSVSR_vimeo90k_x4 | Vimeo90K | [PP-MSVSR_vimeo90k_x4](https://paddlegan.bj.bcebos.com/models/PP-MSVSR_vimeo90k_x4.pdparams)
......@@ -133,3 +198,10 @@ The metrics are PSNR / SSIM.
year = {2021}
}
```
- 4. [PP-MSVSR: Multi-Stage Video Super-Resolution]()
```
@article{
}
```
......@@ -3,16 +3,22 @@
## 1.1 原理介绍
视频超分源于图像超分,其目的是从一个或多个低分辨率(LR)图像中恢复高分辨率(HR)图像。它们的区别也很明显,由于视频是由多个帧组成的,所以视频超分通常利用帧间的信息来进行修复。这里我们提供视频超分模型[EDVR](https://arxiv.org/pdf/1905.02716.pdf),[BasicVSR](https://arxiv.org/pdf/2012.02181.pdf),[IconVSR](https://arxiv.org/pdf/2012.02181.pdf),[BasicVSR++](https://arxiv.org/pdf/2104.13371v1.pdf).
视频超分源于图像超分,其目的是从一个或多个低分辨率(LR)图像中恢复高分辨率(HR)图像。它们的区别也很明显,由于视频是由多个帧组成的,所以视频超分通常利用帧间的信息来进行修复。这里我们提供视频超分模型[EDVR](https://arxiv.org/pdf/1905.02716.pdf)[BasicVSR](https://arxiv.org/pdf/2012.02181.pdf)[IconVSR](https://arxiv.org/pdf/2012.02181.pdf)[BasicVSR++](https://arxiv.org/pdf/2104.13371v1.pdf)和PP-MSVSR。
[EDVR](https://arxiv.org/pdf/1905.02716.pdf)模型在NTIRE19视频恢复和增强挑战赛的四个赛道中都赢得了冠军,并以巨大的优势超过了第二名。视频超分的主要难点在于(1)如何在给定大运动的情况下对齐多个帧;(2)如何有效地融合具有不同运动和模糊的不同帧。首先,为了处理大的运动,EDVR模型设计了一个金字塔级联的可变形(PCD)对齐模块,在该模块中,从粗到精的可变形卷积被使用来进行特征级的帧对齐。其次,EDVR使用了时空注意力(TSA)融合模块,该模块在时间和空间上同时应用注意力机制,以强调后续恢复的重要特征。
[BasicVSR](https://arxiv.org/pdf/2012.02181.pdf)在VSR的指导下重新考虑了四个基本模块(即传播、对齐、聚合和上采样)的一些最重要的组件。 通过添加一些小设计,重用一些现有组件,得到了简洁的 BasicVSR。与许多最先进的算法相比,BasicVSR在速度和恢复质量方面实现了有吸引力的改进。 同时,通过添加信息重新填充机制和耦合传播方案以促进信息聚合,BasicVSR 可以扩展为 [IconVSR](https://arxiv.org/pdf/2012.02181.pdf),IconVSR可以作为未来 VSR 方法的强大基线 .
[BasicVSR++](https://arxiv.org/pdf/2104.13371v1.pdf)通过提出二阶网格传播和导流可变形对齐来重新设计BasicVSR。通过增强传播和对齐来增强循环框架,BasicVSR++可以更有效地利用未对齐视频帧的时空信息。 在类似的计算约束下,新组件可提高性能。特别是,BasicVSR++ 以相似的参数数量在 PSNR 方面比 BasicVSR 高0.82dB。BasicVSR++ 在NTIRE2021的视频超分辨率和压缩视频增强挑战赛中获得三名冠军和一名亚军。
PP-MSVSR是一种多阶段视频超分深度架构,具有局部融合模块、辅助损失和细化对齐模块,以逐步细化增强结果。具体来说,在第一阶段设计了局部融合模块,在特征传播之前进行局部特征融合, 以加强特征传播中跨帧特征的融合。在第二阶段中引入了一个辅助损失,使传播模块获得的特征保留了更多与HR空间相关的信息。在第三阶段中引入了一个细化的对齐模块,以充分利用前一阶段传播模块的特征信息。大量实验证实,PP-MSVSR在Vid4数据集性能优异,仅使用 1.45M 参数PSNR指标即可达到28.13dB。
## 1.2 如何使用
### 1.2.1 数据准备
这里提供4个视频超分辨率常用数据集,REDS,Vimeo90K,Vid4,UDM10。其中REDS和vimeo90k数据集包括训练集和测试集,Vid4和UDM10为测试数据集。将需要的数据集下载解压后放到``PaddleGAN/data``文件夹下 。
REDS([数据下载](https://seungjunnah.github.io/Datasets/reds.html))数据集是NTIRE19公司最新提出的高质量(720p)视频数据集,其由240个训练片段、30个验证片段和30个测试片段组成(每个片段有100个连续帧)。由于测试数据集不可用,这里在训练集选择了四个具有代表性的片段(分别为'000', '011', '015', '020',它们具有不同的场景和动作)作为测试集,用REDS4表示。剩下的训练和验证片段被重新分组为训练数据集(总共266个片段)。
处理后的数据集 REDS 的组成形式如下:
......@@ -31,6 +37,49 @@
...
```
Vimeo90K([数据下载](http://toflow.csail.mit.edu/))数据集是Tianfan Xue等人构建的一个用于视频超分、视频降噪、视频去伪影、视频插帧的数据集。Vimeo90K是大规模、高质量的视频数据集,包含从vimeo.com下载的 89,800 个视频剪辑,涵盖了大量场景和动作。
处理后的数据集 Vimeo90K 的组成形式如下:
```
PaddleGAN
├── data
├── Vimeo90K
├── vimeo_septuplet
| |──sequences
| └──sep_trainlist.txt
├── vimeo_septuplet_BD_matlabLRx4
| └──sequences
└── vimeo_super_resolution_test
|──low_resolution
|──target
└──sep_testlist.txt
...
```
Vid4([数据下载](https://paddlegan.bj.bcebos.com/datasets/Vid4.zip))数据集是常用的视频超分验证数据集,包含4个视频段。
处理后的数据集 Vid4 的组成形式如下:
```
PaddleGAN
├── data
├── Vid4
├── BDx4
└── GT
...
```
UDM10([数据下载](https://paddlegan.bj.bcebos.com/datasets/udm10_paddle.tar))数据集是常用的视频超分验证数据集,包含10个视频段。
处理后的数据集 UDM10 的组成形式如下:
```
PaddleGAN
├── data
├── udm10
├── BDx4
└── GT
...
```
### 1.2.2 训练/测试
EDVR模型根据模型中间通道数分为EDVR_L(128通道)和EDVR_M(64通道)两种模型。下面以EDVR_M模型为例介绍模型训练与测试。
......@@ -59,23 +108,37 @@
python tools/main.py --config-file configs/edvr_m_w_tsa.yaml --evaluate-only --load ${PATH_OF_WEIGHT}
```
训练或测试其他视频超分模型,可以在``PaddleGAN/configs``文件夹下找到对应模型的配置文件,将命令中的配置文件改成该视频超分模型的配置文件即可。
## 1.3 实验结果展示
实验数值结果是在 RGB 通道上进行评估。
度量指标为 PSNR / SSIM.
REDS的测试数据集REDS4上的超分性能对比
| 模型| 参数量(M) | 计算量(G) | REDS4 |
|---|---|---|---|
| EDVR_M_wo_tsa_SRx4 | 3.00 | 223 | 30.4429 / 0.8684 |
| EDVR_M_w_tsa_SRx4 | 3.30 | 232 | 30.5169 / 0.8699 |
| EDVR_L_wo_tsa_SRx4 | 19.42 | 974 | 30.8649 / 0.8761 |
| EDVR_L_w_tsa_SRx4 | 20.63 | 1010 | 30.9336 / 0.8773 |
| BasicVSR_x4 | 6.29 | 374 | 31.4325 / 0.8913 |
| IconVSR_x4 | 8.69 | 516 | 31.6882 / 0.8950 |
| BasicVSR++_x4 | 7.32 | 406 | 32.4018 / 0.9071 |
| PP-MSVSR_reds_x4 | 1.45 | 111 | 31.2535 / 0.8884 |
| PP-MSVSR-L_reds_x4 | 7.42 | 543 | 32.5321 / 0.9083 |
REDS的测试数据集REDS4上的去模糊性能对比
| 模型 | REDS4 |
|---|---|
| EDVR_M_wo_tsa_SRx4 | 30.4429 / 0.8684 |
| EDVR_M_w_tsa_SRx4 | 30.5169 / 0.8699 |
| EDVR_L_wo_tsa_SRx4 | 30.8649 / 0.8761 |
| EDVR_L_w_tsa_SRx4 | 30.9336 / 0.8773 |
| EDVR_L_wo_tsa_deblur | 34.9587 / 0.9509 |
| EDVR_L_w_tsa_deblur | 35.1473 / 0.9526 |
| BasicVSR_x4 | 31.4325 / 0.8913 |
| IconVSR_x4 | 31.6882 / 0.8950 |
| BasicVSR++_x4 | 32.4018 / 0.9071 |
Vimeo90K,Vid4,UDM10测试数据集上超分性能对比
| 模型 | Vimeo90K | Vid4 | UDM10 |
|---|---|---|---|
| PP-MSVSR_vimeo90k_x4 |37.54/0.9499|28.13/0.8604|40.06/0.9699|
## 1.4 模型下载
| 模型 | 数据集 | 下载地址 |
......@@ -89,8 +152,9 @@
| BasicVSR_x4 | REDS | [BasicVSR_x4](https://paddlegan.bj.bcebos.com/models/BasicVSR_reds_x4.pdparams)
| IconVSR_x4 | REDS | [IconVSR_x4](https://paddlegan.bj.bcebos.com/models/IconVSR_reds_x4.pdparams)
| BasicVSR++_x4 | REDS | [BasicVSR++_x4](https://paddlegan.bj.bcebos.com/models/BasicVSR%2B%2B_reds_x4.pdparams)
| PP-MSVSR_reds_x4 | REDS | [PP-MSVSR_reds_x4](https://paddlegan.bj.bcebos.com/models/PP-MSVSR_reds_x4.pdparams)
| PP-MSVSR-L_reds_x4 | REDS | [PP-MSVSR-L_reds_x4](https://paddlegan.bj.bcebos.com/models/PP-MSVSR-L_reds_x4.pdparams)
| PP-MSVSR_vimeo90k_x4 | Vimeo90K | [PP-MSVSR_vimeo90k_x4](https://paddlegan.bj.bcebos.com/models/PP-MSVSR_vimeo90k_x4.pdparams)
# 参考文献
......@@ -125,3 +189,10 @@
year = {2021}
}
```
- 4. [PP-MSVSR: Multi-Stage Video Super-Resolution]()
```
@article{
}
```
......@@ -212,8 +212,8 @@ def calculate_ssim(img1,
f'Wrong input_order {input_order}. Supported input_orders are '
'"HWC" and "CHW"')
img1 = img1.copy().astype('float32')[..., ::-1]
img2 = img2.copy().astype('float32')[..., ::-1]
img1 = img1.copy().astype('float32')
img2 = img2.copy().astype('float32')
img1 = reorder_image(img1, input_order=input_order)
img2 = reorder_image(img2, input_order=input_order)
......
......@@ -33,3 +33,4 @@ from .lapstyle_model import LapStyleDraModel, LapStyleRevFirstModel, LapStyleRev
from .basicvsr_model import BasicVSRModel
from .mpr_model import MPRModel
from .photopen_model import PhotoPenModel
from .msvsr_model import MultiStageVSRModel
......@@ -36,4 +36,5 @@ from .iconvsr import IconVSR
from .gpen import GPEN
from .pan import PAN
from .generater_photopen import SPADEGenerator
from .basicvsr_plus_plus import BasicVSRPlusPlus
\ No newline at end of file
from .basicvsr_plus_plus import BasicVSRPlusPlus
from .msvsr import MSVSR
此差异已折叠。
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
from .builder import MODELS
from .sr_model import BaseSRModel
from .generators.basicvsr import ResidualBlockNoBN, PixelShufflePack, SPyNet
from .generators.msvsr import ModifiedSPyNet
from ..modules.init import reset_parameters
from ..utils.visual import tensor2img
@MODELS.register()
class MultiStageVSRModel(BaseSRModel):
"""PP-MSVSR Model.
Paper:
PP-MSVSR: Multi-Stage Video Super-Resolution, 2021
"""
def __init__(self, generator, fix_iter, pixel_criterion=None):
"""Initialize the PP-MSVSR class.
Args:
generator (dict): config of generator.
fix_iter (dict): config of fix_iter.
pixel_criterion (dict): config of pixel criterion.
"""
super(MultiStageVSRModel, self).__init__(generator, pixel_criterion)
self.fix_iter = fix_iter
self.current_iter = 1
self.flag = True
init_basicvsr_weight(self.nets['generator'])
if not self.fix_iter:
print('init train all parameters!!!')
for name, param in self.nets['generator'].named_parameters():
param.trainable = True
if 'spynet' in name:
param.optimize_attr['learning_rate'] = 0.25
def setup_input(self, input):
self.lq = paddle.to_tensor(input['lq'])
self.visual_items['lq'] = self.lq[:, 0, :, :, :]
if 'gt' in input:
self.gt = paddle.to_tensor(input['gt'])
self.visual_items['gt'] = self.gt[:, 0, :, :, :]
self.image_paths = input['lq_path']
def train_iter(self, optims=None):
optims['optim'].clear_grad()
if self.fix_iter:
if self.current_iter == 1:
print('Train MSVSR with fixed spynet for', self.fix_iter,
'iters.')
for name, param in self.nets['generator'].named_parameters():
if 'spynet' in name:
param.trainable = False
elif self.current_iter >= self.fix_iter + 1 and self.flag:
print('Train all the parameters.')
for name, param in self.nets['generator'].named_parameters():
param.trainable = True
if 'spynet' in name:
param.optimize_attr['learning_rate'] = 0.25
self.flag = False
for net in self.nets.values():
net.find_unused_parameters = False
output = self.nets['generator'](self.lq)
if isinstance(output, (list, tuple)):
out_stage2, output = output
loss_pix_stage2 = self.pixel_criterion(out_stage2, self.gt)
self.losses['loss_pix_stage2'] = loss_pix_stage2
self.visual_items['output'] = output[:, 0, :, :, :]
# pixel loss
loss_pix = self.pixel_criterion(output, self.gt)
self.losses['loss_pix'] = loss_pix
self.loss = sum(_value for _key, _value in self.losses.items()
if 'loss_pix' in _key)
self.losses['loss'] = self.loss
self.loss.backward()
optims['optim'].step()
self.current_iter += 1
def test_iter(self, metrics=None):
self.gt = self.gt.cpu()
self.nets['generator'].eval()
with paddle.no_grad():
output = self.nets['generator'](self.lq)
if isinstance(output, (list, tuple)):
out_stage1, output = output
self.nets['generator'].train()
out_img = []
gt_img = []
_, t, _, _, _ = self.gt.shape
for i in range(t):
out_tensor = output[0, i]
gt_tensor = self.gt[0, i]
out_img.append(tensor2img(out_tensor, (0., 1.)))
gt_img.append(tensor2img(gt_tensor, (0., 1.)))
if metrics is not None:
for metric in metrics.values():
metric.update(out_img, gt_img, is_seq=True)
def init_basicvsr_weight(net):
for m in net.children():
if hasattr(m,
'weight') and not isinstance(m,
(nn.BatchNorm, nn.BatchNorm2D)):
reset_parameters(m)
continue
if (not isinstance(
m,
(ResidualBlockNoBN, PixelShufflePack, SPyNet, ModifiedSPyNet))):
init_basicvsr_weight(m)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册