From c12e50aa0ec94649b8834f95104d2b7377808258 Mon Sep 17 00:00:00 2001 From: wangna11BD <79366697+wangna11BD@users.noreply.github.com> Date: Wed, 12 May 2021 20:34:38 +0800 Subject: [PATCH] add EDVR128 (#308) --- configs/edvr_l_w_tsa.yaml | 99 +++++++++++++++++++ configs/edvr_l_wo_tsa.yaml | 99 +++++++++++++++++++ configs/{edvr.yaml => edvr_m_w_tsa.yaml} | 3 +- .../{edvr_wo_tsa.yaml => edvr_m_wo_tsa.yaml} | 3 +- .../en_US/tutorials/video_super_resolution.md | 32 ++++-- .../zh_CN/tutorials/video_super_resolution.md | 30 ++++-- ppgan/models/generators/edvr.py | 59 +++-------- 7 files changed, 262 insertions(+), 63 deletions(-) create mode 100644 configs/edvr_l_w_tsa.yaml create mode 100644 configs/edvr_l_wo_tsa.yaml rename configs/{edvr.yaml => edvr_m_w_tsa.yaml} (97%) rename configs/{edvr_wo_tsa.yaml => edvr_m_wo_tsa.yaml} (97%) diff --git a/configs/edvr_l_w_tsa.yaml b/configs/edvr_l_w_tsa.yaml new file mode 100644 index 0000000..97ce169 --- /dev/null +++ b/configs/edvr_l_w_tsa.yaml @@ -0,0 +1,99 @@ +total_iters: 600000 +output_dir: output_dir +checkpoints_dir: checkpoints +# tensor range for function tensor2img +min_max: + (0., 1.) + +model: + name: EDVRModel + tsa_iter: 50000 + generator: + name: EDVRNet + in_nf: 3 + out_nf: 3 + scale_factor: 4 + nf: 128 + nframes: 5 + groups: 8 + front_RBs: 5 + back_RBs: 40 + center: 2 + predeblur: False + HR_in: False + w_TSA: True + pixel_criterion: + name: CharbonnierLoss + +dataset: + train: + name: REDSDataset + mode: train + gt_folder: data/REDS/train_sharp/X4 + lq_folder: data/REDS/train_sharp_bicubic/X4 + img_format: png + crop_size: 256 + interval_list: [1] + random_reverse: False + number_frames: 5 + use_flip: True + use_rot: True + buf_size: 1024 + scale: 4 + fix_random_seed: 10 + num_workers: 3 + batch_size: 4 # 8GUPs + + + test: + name: REDSDataset + mode: test + gt_folder: data/REDS/REDS4_test_sharp/X4 + lq_folder: data/REDS/REDS4_test_sharp_bicubic/X4 + img_format: png + interval_list: [1] + random_reverse: False + number_frames: 5 + batch_size: 1 + use_flip: False + use_rot: False + buf_size: 1024 + scale: 4 + fix_random_seed: 10 + +lr_scheduler: + name: CosineAnnealingRestartLR + learning_rate: !!float 4e-4 + periods: [50000, 100000, 150000, 150000, 150000] + restart_weights: [1, 0.5, 0.5, 0.5, 0.5] + eta_min: !!float 1e-7 + +optimizer: + name: Adam + # add parameters of net_name to optim + # name should in self.nets + net_names: + - generator + beta1: 0.9 + beta2: 0.99 + +validate: + interval: 5000 + save_img: false + + metrics: + psnr: # metric name, can be arbitrary + name: PSNR + crop_border: 0 + test_y_channel: False + ssim: + name: SSIM + crop_border: 0 + test_y_channel: False + +log_config: + interval: 10 + visiual_interval: 5000 + +snapshot_config: + interval: 5000 diff --git a/configs/edvr_l_wo_tsa.yaml b/configs/edvr_l_wo_tsa.yaml new file mode 100644 index 0000000..facbed5 --- /dev/null +++ b/configs/edvr_l_wo_tsa.yaml @@ -0,0 +1,99 @@ +total_iters: 600000 +output_dir: output_dir +checkpoints_dir: checkpoints +# tensor range for function tensor2img +min_max: + (0., 1.) + +model: + name: EDVRModel + tsa_iter: 0 + generator: + name: EDVRNet + in_nf: 3 + out_nf: 3 + scale_factor: 4 + nf: 128 + nframes: 5 + groups: 8 + front_RBs: 5 + back_RBs: 40 + center: 2 + predeblur: False + HR_in: False + w_TSA: False + pixel_criterion: + name: CharbonnierLoss + +dataset: + train: + name: REDSDataset + mode: train + gt_folder: data/REDS/train_sharp/X4 + lq_folder: data/REDS/train_sharp_bicubic/X4 + img_format: png + crop_size: 256 + interval_list: [1] + random_reverse: False + number_frames: 5 + use_flip: True + use_rot: True + buf_size: 1024 + scale: 4 + fix_random_seed: 10 + num_workers: 3 + batch_size: 4 # 8GUPs + + + test: + name: REDSDataset + mode: test + gt_folder: data/REDS/REDS4_test_sharp/X4 + lq_folder: data/REDS/REDS4_test_sharp_bicubic/X4 + img_format: png + interval_list: [1] + random_reverse: False + number_frames: 5 + batch_size: 1 + use_flip: False + use_rot: False + buf_size: 1024 + scale: 4 + fix_random_seed: 10 + +lr_scheduler: + name: CosineAnnealingRestartLR + learning_rate: !!float 4e-4 + periods: [150000, 150000, 150000, 150000] + restart_weights: [1, 0.5, 0.5, 0.5] + eta_min: !!float 1e-7 + +optimizer: + name: Adam + # add parameters of net_name to optim + # name should in self.nets + net_names: + - generator + beta1: 0.9 + beta2: 0.99 + +validate: + interval: 5000 + save_img: false + + metrics: + psnr: # metric name, can be arbitrary + name: PSNR + crop_border: 0 + test_y_channel: False + ssim: + name: SSIM + crop_border: 0 + test_y_channel: False + +log_config: + interval: 10 + visiual_interval: 500 + +snapshot_config: + interval: 5000 diff --git a/configs/edvr.yaml b/configs/edvr_m_w_tsa.yaml similarity index 97% rename from configs/edvr.yaml rename to configs/edvr_m_w_tsa.yaml index 8aa206a..79a9e4c 100644 --- a/configs/edvr.yaml +++ b/configs/edvr_m_w_tsa.yaml @@ -22,7 +22,6 @@ model: predeblur: False HR_in: False w_TSA: True - TSA_only: False pixel_criterion: name: CharbonnierLoss @@ -43,7 +42,7 @@ dataset: scale: 4 fix_random_seed: 10 num_workers: 3 - batch_size: 4 + batch_size: 4 # 8GUPs test: diff --git a/configs/edvr_wo_tsa.yaml b/configs/edvr_m_wo_tsa.yaml similarity index 97% rename from configs/edvr_wo_tsa.yaml rename to configs/edvr_m_wo_tsa.yaml index 776da6a..2891fdd 100644 --- a/configs/edvr_wo_tsa.yaml +++ b/configs/edvr_m_wo_tsa.yaml @@ -22,7 +22,6 @@ model: predeblur: False HR_in: False w_TSA: False - TSA_only: False pixel_criterion: name: CharbonnierLoss @@ -43,7 +42,7 @@ dataset: scale: 4 fix_random_seed: 10 num_workers: 3 - batch_size: 4 + batch_size: 4 # 8GUPs test: diff --git a/docs/en_US/tutorials/video_super_resolution.md b/docs/en_US/tutorials/video_super_resolution.md index 688f0e2..5821cbb 100644 --- a/docs/en_US/tutorials/video_super_resolution.md +++ b/docs/en_US/tutorials/video_super_resolution.md @@ -13,7 +13,7 @@ ### 1.2.1 Prepare Datasets - REDS([download](https://seungjunnah.github.io/Datasets/reds.html))is a newly proposed high-quality (720p) video dataset in the NTIRE19 Competition. REDS consists of 240 training clips, 30 validation clips and 30 testing clips (each with 100 consecutive frames). Since the test ground truth is not available, we select four representative clips (they are '000', '011', '015', '020', with diverse scenes and motions) as our test set, denoted by REDS4. The remaining training and validation clips are re-grouped as our training dataset (a total of 266 clips). + REDS([download](https://seungjunnah.github.io/Datasets/reds.html))is a newly proposed high-quality (720p) video dataset in the NTIRE19 Competition. REDS consists of 240 training clips, 30 validation clips and 30 testing clips (each with 100 consecutive frames). Since the test ground truth is not available, we select four representative clips (they are '000', '011', '015', '020', with diverse scenes and motions) as our test set, denoted by REDS4. The remaining training and validation clips are re-grouped as our training dataset (a total of 266 clips). The structure of the processed REDS is as follows: ``` @@ -33,28 +33,48 @@ ### 1.2.2 Train/Test - The command to train and test edvr model with the processed EDVR is as follows: + According to the number of channels, EDVR are divided into EDVR_L(128 channels) and EDVR_M (64 channels). Then, taking EDVR_M as an example, the model training and testing are introduced. + + The train of EDVR is generally divided into two stages. First, train EDVR without TSA module. + + The command to train and test edvr without TSA module is as follows: + + Train a model: + ``` + python -u tools/main.py --config-file configs/edvr_m_wo_tsa.yaml + ``` + + Test the model: + ``` + python tools/main.py --config-file configs/edvr_m_wo_tsa.yaml --evaluate-only --load ${PATH_OF_WEIGHT_WITHOUT_TSA} + ``` + + Then the weight of EDVR without TSA module is used as the initialization of edvr model to train the complete edvr model. + + The command to train and test edvr is as follows: Train a model: ``` - python -u tools/main.py --config-file configs/edvr.yaml + python -u tools/main.py --config-file configs/edvr_m_w_tsa.yaml --load ${PATH_OF_WEIGHT_WITHOUT_TSA} ``` Test the model: ``` - python tools/main.py --config-file configs/edvr.yaml --evaluate-only --load ${PATH_OF_WEIGHT} + python tools/main.py --config-file configs/edvr_m_w_tsa.yaml --evaluate-only --load ${PATH_OF_WEIGHT} ``` + ## 1.3 Results The experimental results are evaluated on RGB channel. The metrics are PSNR / SSIM. -| Method | REDS4 | +| Method | REDS4 | |---|---| | EDVR_M_wo_tsa_SRx4 | 30.4429 / 0.8684 | | EDVR_M_w_tsa_SRx4 | 30.5169 / 0.8699 | | EDVR_L_wo_tsa_SRx4 | 30.8649 / 0.8761 | +| EDVR_L_w_tsa_SRx4 | 30.9336 / 0.8773 | ## 1.4 Model Download @@ -63,6 +83,7 @@ The metrics are PSNR / SSIM. | EDVR_M_wo_tsa_SRx4 | REDS | [EDVR_M_wo_tsa_SRx4](https://paddlegan.bj.bcebos.com/models/EDVR_M_wo_tsa_SRx4.pdparams) | EDVR_M_w_tsa_SRx4 | REDS | [EDVR_M_w_tsa_SRx4](https://paddlegan.bj.bcebos.com/models/EDVR_M_w_tsa_SRx4.pdparams) | EDVR_L_wo_tsa_SRx4 | REDS | [EDVR_L_wo_tsa_SRx4](https://paddlegan.bj.bcebos.com/models/EDVR_L_wo_tsa_SRx4.pdparams) +| EDVR_L_w_tsa_SRx4 | REDS | [EDVR_L_w_tsa_SRx4](https://paddlegan.bj.bcebos.com/models/EDVR_L_w_tsa_SRx4.pdparams) @@ -81,4 +102,3 @@ The metrics are PSNR / SSIM. year = {2019} } ``` - diff --git a/docs/zh_CN/tutorials/video_super_resolution.md b/docs/zh_CN/tutorials/video_super_resolution.md index c31f4e0..ab69bf6 100644 --- a/docs/zh_CN/tutorials/video_super_resolution.md +++ b/docs/zh_CN/tutorials/video_super_resolution.md @@ -4,12 +4,12 @@ ## 1.1 原理介绍 视频超分源于图像超分,其目的是从一个或多个低分辨率(LR)图像中恢复高分辨率(HR)图像。它们的区别也很明显,由于视频是由多个帧组成的,所以视频超分通常利用帧间的信息来进行修复。这里我们提供视频超分模型[EDVR](https://arxiv.org/pdf/1905.02716.pdf). - + [EDVR](https://arxiv.org/pdf/1905.02716.pdf)模型在NTIRE19视频恢复和增强挑战赛的四个赛道中都赢得了冠军,并以巨大的优势超过了第二名。视频超分的主要难点在于(1)如何在给定大运动的情况下对齐多个帧;(2)如何有效地融合具有不同运动和模糊的不同帧。首先,为了处理大的运动,EDVR模型设计了一个金字塔级联的可变形(PCD)对齐模块,在该模块中,从粗到精的可变形卷积被使用来进行特征级的帧对齐。其次,EDVR使用了时空注意力(TSA)融合模块,该模块在时间和空间上同时应用注意力机制,以强调后续恢复的重要特征。 -## 1.2 如何使用 +## 1.2 如何使用 ### 1.2.1 数据准备 @@ -33,28 +33,44 @@ ### 1.2.2 训练/测试 - 使用处理后的REDS数据集训练与测试EDVR模型命令如下: + EDVR模型根据模型中间通道数分为EDVR_L(128通道)和EDVR_M(64通道)两种模型。下面以EDVR_M模型为例介绍模型训练与测试。 + + EDVR模型训练一般分两个阶段训练,先不带TSA模块训练,训练与测试命令如下: 训练模型: ``` - python -u tools/main.py --config-file configs/edvr.yaml + python -u tools/main.py --config-file configs/edvr_m_wo_tsa.yaml ``` 测试模型: ``` - python tools/main.py --config-file configs/edvr.yaml --evaluate-only --load ${PATH_OF_WEIGHT} + python tools/main.py --config-file configs/edvr_m_wo_tsa.yaml --evaluate-only --load ${PATH_OF_WEIGHT_WITHOUT_TSA} + ``` + + 然后用保存的不带TSA模块的EDVR权重作为EDVR模型的初始化,训练完整的EDVR模型,训练与测试命令如下: + + 训练模型: + ``` + python -u tools/main.py --config-file configs/edvr_m_w_tsa.yaml --load ${PATH_OF_WEIGHT_WITHOUT_TSA} ``` + 测试模型: + ``` + python tools/main.py --config-file configs/edvr_m_w_tsa.yaml --evaluate-only --load ${PATH_OF_WEIGHT} + ``` + + ## 1.3 实验结果展示 实验数值结果是在 RGB 通道上进行评估。 度量指标为 PSNR / SSIM. -| 模型 | REDS4 | +| 模型 | REDS4 | |---|---| | EDVR_M_wo_tsa_SRx4 | 30.4429 / 0.8684 | | EDVR_M_w_tsa_SRx4 | 30.5169 / 0.8699 | | EDVR_L_wo_tsa_SRx4 | 30.8649 / 0.8761 | +| EDVR_L_w_tsa_SRx4 | 30.9336 / 0.8773 | ## 1.4 模型下载 @@ -63,6 +79,7 @@ | EDVR_M_wo_tsa_SRx4 | REDS | [EDVR_M_wo_tsa_SRx4](https://paddlegan.bj.bcebos.com/models/EDVR_M_wo_tsa_SRx4.pdparams) | EDVR_M_w_tsa_SRx4 | REDS | [EDVR_M_w_tsa_SRx4](https://paddlegan.bj.bcebos.com/models/EDVR_M_w_tsa_SRx4.pdparams) | EDVR_L_wo_tsa_SRx4 | REDS | [EDVR_L_wo_tsa_SRx4](https://paddlegan.bj.bcebos.com/models/EDVR_L_wo_tsa_SRx4.pdparams) +| EDVR_L_w_tsa_SRx4 | REDS | [EDVR_L_w_tsa_SRx4](https://paddlegan.bj.bcebos.com/models/EDVR_L_w_tsa_SRx4.pdparams) @@ -81,4 +98,3 @@ year = {2019} } ``` - diff --git a/ppgan/models/generators/edvr.py b/ppgan/models/generators/edvr.py index 57cb859..edb23a6 100644 --- a/ppgan/models/generators/edvr.py +++ b/ppgan/models/generators/edvr.py @@ -15,10 +15,8 @@ import paddle import numpy as np -import scipy.io as scio import paddle.nn as nn -from paddle.nn import initializer from ...modules.init import kaiming_normal_, constant_ from ...modules.dcn import DeformableConv_dygraph @@ -63,11 +61,8 @@ class ResidualBlockNoBN(nn.Layer): |________________| Args: - num_feat (int): Channel number of intermediate features. + nf (int): Channel number of intermediate features. Default: 64. - res_scale (float): Residual scale. Default: 1. - pytorch_init (bool): If set to True, use pytorch default init, - otherwise, use default_init_weights. Default: False. """ def __init__(self, nf=64): super(ResidualBlockNoBN, self).__init__() @@ -612,8 +607,7 @@ class EDVRNet(nn.Layer): center=None, predeblur=False, HR_in=False, - w_TSA=True, - TSA_only=False): + w_TSA=True): super(EDVRNet, self).__init__() self.in_nf = in_nf self.out_nf = out_nf @@ -638,28 +632,11 @@ class EDVRNet(nn.Layer): kernel_size=1, stride=1) else: - if self.HR_in: - self.conv_first_1 = nn.Conv2D(in_channels=self.in_nf, - out_channels=self.nf, - kernel_size=3, - stride=1, - padding=1) - self.conv_first_2 = nn.Conv2D(in_channels=self.nf, - out_channels=self.nf, - kernel_size=3, - stride=2, - padding=1) - self.conv_first_3 = nn.Conv2D(in_channels=self.nf, - out_channels=self.nf, - kernel_size=3, - stride=2, - padding=1) - else: - self.conv_first = nn.Conv2D(in_channels=self.in_nf, - out_channels=self.nf, - kernel_size=3, - stride=1, - padding=1) + self.conv_first = nn.Conv2D(in_channels=self.in_nf, + out_channels=self.nf, + kernel_size=3, + stride=1, + padding=1) #feature extraction module self.feature_extractor = MakeMultiBlocks(ResidualBlockNoBN, @@ -711,16 +688,16 @@ class EDVRNet(nn.Layer): padding=1) self.pixel_shuffle = nn.PixelShuffle(2) self.upconv2 = nn.Conv2D(in_channels=self.nf, - out_channels=4 * self.nf, + out_channels=4 * 64, kernel_size=3, stride=1, padding=1) - self.HRconv = nn.Conv2D(in_channels=self.nf, - out_channels=self.nf, + self.HRconv = nn.Conv2D(in_channels=64, + out_channels=64, kernel_size=3, stride=1, padding=1) - self.conv_last = nn.Conv2D(in_channels=self.nf, + self.conv_last = nn.Conv2D(in_channels=64, out_channels=self.out_nf, kernel_size=3, stride=1, @@ -747,18 +724,8 @@ class EDVRNet(nn.Layer): if self.HR_in: H, W = H // self.scale_factor, W // self.scale_factor else: - if self.HR_in: - L1_fea = self.conv_first_1(L1_fea) - L1_fea = self.Leaky_relu(L1_fea) - L1_fea = self.conv_first_2(L1_fea) - L1_fea = self.Leaky_relu(L1_fea) - L1_fea = self.conv_first_3(L1_fea) - L1_fea = self.Leaky_relu(L1_fea) - H = H // self.scale_factor - W = W // self.scale_factor - else: - L1_fea = self.conv_first(L1_fea) - L1_fea = self.Leaky_relu(L1_fea) + L1_fea = self.conv_first(L1_fea) + L1_fea = self.Leaky_relu(L1_fea) # feature extraction and create Pyramid L1_fea = self.feature_extractor(L1_fea) -- GitLab