未验证 提交 c12e50aa 编写于 作者: W wangna11BD 提交者: GitHub

add EDVR128 (#308)

上级 94adf2c4
total_iters: 600000
output_dir: output_dir
checkpoints_dir: checkpoints
# tensor range for function tensor2img
min_max:
(0., 1.)
model:
name: EDVRModel
tsa_iter: 50000
generator:
name: EDVRNet
in_nf: 3
out_nf: 3
scale_factor: 4
nf: 128
nframes: 5
groups: 8
front_RBs: 5
back_RBs: 40
center: 2
predeblur: False
HR_in: False
w_TSA: True
pixel_criterion:
name: CharbonnierLoss
dataset:
train:
name: REDSDataset
mode: train
gt_folder: data/REDS/train_sharp/X4
lq_folder: data/REDS/train_sharp_bicubic/X4
img_format: png
crop_size: 256
interval_list: [1]
random_reverse: False
number_frames: 5
use_flip: True
use_rot: True
buf_size: 1024
scale: 4
fix_random_seed: 10
num_workers: 3
batch_size: 4 # 8GUPs
test:
name: REDSDataset
mode: test
gt_folder: data/REDS/REDS4_test_sharp/X4
lq_folder: data/REDS/REDS4_test_sharp_bicubic/X4
img_format: png
interval_list: [1]
random_reverse: False
number_frames: 5
batch_size: 1
use_flip: False
use_rot: False
buf_size: 1024
scale: 4
fix_random_seed: 10
lr_scheduler:
name: CosineAnnealingRestartLR
learning_rate: !!float 4e-4
periods: [50000, 100000, 150000, 150000, 150000]
restart_weights: [1, 0.5, 0.5, 0.5, 0.5]
eta_min: !!float 1e-7
optimizer:
name: Adam
# add parameters of net_name to optim
# name should in self.nets
net_names:
- generator
beta1: 0.9
beta2: 0.99
validate:
interval: 5000
save_img: false
metrics:
psnr: # metric name, can be arbitrary
name: PSNR
crop_border: 0
test_y_channel: False
ssim:
name: SSIM
crop_border: 0
test_y_channel: False
log_config:
interval: 10
visiual_interval: 5000
snapshot_config:
interval: 5000
total_iters: 600000
output_dir: output_dir
checkpoints_dir: checkpoints
# tensor range for function tensor2img
min_max:
(0., 1.)
model:
name: EDVRModel
tsa_iter: 0
generator:
name: EDVRNet
in_nf: 3
out_nf: 3
scale_factor: 4
nf: 128
nframes: 5
groups: 8
front_RBs: 5
back_RBs: 40
center: 2
predeblur: False
HR_in: False
w_TSA: False
pixel_criterion:
name: CharbonnierLoss
dataset:
train:
name: REDSDataset
mode: train
gt_folder: data/REDS/train_sharp/X4
lq_folder: data/REDS/train_sharp_bicubic/X4
img_format: png
crop_size: 256
interval_list: [1]
random_reverse: False
number_frames: 5
use_flip: True
use_rot: True
buf_size: 1024
scale: 4
fix_random_seed: 10
num_workers: 3
batch_size: 4 # 8GUPs
test:
name: REDSDataset
mode: test
gt_folder: data/REDS/REDS4_test_sharp/X4
lq_folder: data/REDS/REDS4_test_sharp_bicubic/X4
img_format: png
interval_list: [1]
random_reverse: False
number_frames: 5
batch_size: 1
use_flip: False
use_rot: False
buf_size: 1024
scale: 4
fix_random_seed: 10
lr_scheduler:
name: CosineAnnealingRestartLR
learning_rate: !!float 4e-4
periods: [150000, 150000, 150000, 150000]
restart_weights: [1, 0.5, 0.5, 0.5]
eta_min: !!float 1e-7
optimizer:
name: Adam
# add parameters of net_name to optim
# name should in self.nets
net_names:
- generator
beta1: 0.9
beta2: 0.99
validate:
interval: 5000
save_img: false
metrics:
psnr: # metric name, can be arbitrary
name: PSNR
crop_border: 0
test_y_channel: False
ssim:
name: SSIM
crop_border: 0
test_y_channel: False
log_config:
interval: 10
visiual_interval: 500
snapshot_config:
interval: 5000
...@@ -22,7 +22,6 @@ model: ...@@ -22,7 +22,6 @@ model:
predeblur: False predeblur: False
HR_in: False HR_in: False
w_TSA: True w_TSA: True
TSA_only: False
pixel_criterion: pixel_criterion:
name: CharbonnierLoss name: CharbonnierLoss
...@@ -43,7 +42,7 @@ dataset: ...@@ -43,7 +42,7 @@ dataset:
scale: 4 scale: 4
fix_random_seed: 10 fix_random_seed: 10
num_workers: 3 num_workers: 3
batch_size: 4 batch_size: 4 # 8GUPs
test: test:
......
...@@ -22,7 +22,6 @@ model: ...@@ -22,7 +22,6 @@ model:
predeblur: False predeblur: False
HR_in: False HR_in: False
w_TSA: False w_TSA: False
TSA_only: False
pixel_criterion: pixel_criterion:
name: CharbonnierLoss name: CharbonnierLoss
...@@ -43,7 +42,7 @@ dataset: ...@@ -43,7 +42,7 @@ dataset:
scale: 4 scale: 4
fix_random_seed: 10 fix_random_seed: 10
num_workers: 3 num_workers: 3
batch_size: 4 batch_size: 4 # 8GUPs
test: test:
......
...@@ -33,18 +33,37 @@ ...@@ -33,18 +33,37 @@
### 1.2.2 Train/Test ### 1.2.2 Train/Test
The command to train and test edvr model with the processed EDVR is as follows: According to the number of channels, EDVR are divided into EDVR_L(128 channels) and EDVR_M (64 channels). Then, taking EDVR_M as an example, the model training and testing are introduced.
The train of EDVR is generally divided into two stages. First, train EDVR without TSA module.
The command to train and test edvr without TSA module is as follows:
Train a model:
```
python -u tools/main.py --config-file configs/edvr_m_wo_tsa.yaml
```
Test the model:
```
python tools/main.py --config-file configs/edvr_m_wo_tsa.yaml --evaluate-only --load ${PATH_OF_WEIGHT_WITHOUT_TSA}
```
Then the weight of EDVR without TSA module is used as the initialization of edvr model to train the complete edvr model.
The command to train and test edvr is as follows:
Train a model: Train a model:
``` ```
python -u tools/main.py --config-file configs/edvr.yaml python -u tools/main.py --config-file configs/edvr_m_w_tsa.yaml --load ${PATH_OF_WEIGHT_WITHOUT_TSA}
``` ```
Test the model: Test the model:
``` ```
python tools/main.py --config-file configs/edvr.yaml --evaluate-only --load ${PATH_OF_WEIGHT} python tools/main.py --config-file configs/edvr_m_w_tsa.yaml --evaluate-only --load ${PATH_OF_WEIGHT}
``` ```
## 1.3 Results ## 1.3 Results
The experimental results are evaluated on RGB channel. The experimental results are evaluated on RGB channel.
...@@ -55,6 +74,7 @@ The metrics are PSNR / SSIM. ...@@ -55,6 +74,7 @@ The metrics are PSNR / SSIM.
| EDVR_M_wo_tsa_SRx4 | 30.4429 / 0.8684 | | EDVR_M_wo_tsa_SRx4 | 30.4429 / 0.8684 |
| EDVR_M_w_tsa_SRx4 | 30.5169 / 0.8699 | | EDVR_M_w_tsa_SRx4 | 30.5169 / 0.8699 |
| EDVR_L_wo_tsa_SRx4 | 30.8649 / 0.8761 | | EDVR_L_wo_tsa_SRx4 | 30.8649 / 0.8761 |
| EDVR_L_w_tsa_SRx4 | 30.9336 / 0.8773 |
## 1.4 Model Download ## 1.4 Model Download
...@@ -63,6 +83,7 @@ The metrics are PSNR / SSIM. ...@@ -63,6 +83,7 @@ The metrics are PSNR / SSIM.
| EDVR_M_wo_tsa_SRx4 | REDS | [EDVR_M_wo_tsa_SRx4](https://paddlegan.bj.bcebos.com/models/EDVR_M_wo_tsa_SRx4.pdparams) | EDVR_M_wo_tsa_SRx4 | REDS | [EDVR_M_wo_tsa_SRx4](https://paddlegan.bj.bcebos.com/models/EDVR_M_wo_tsa_SRx4.pdparams)
| EDVR_M_w_tsa_SRx4 | REDS | [EDVR_M_w_tsa_SRx4](https://paddlegan.bj.bcebos.com/models/EDVR_M_w_tsa_SRx4.pdparams) | EDVR_M_w_tsa_SRx4 | REDS | [EDVR_M_w_tsa_SRx4](https://paddlegan.bj.bcebos.com/models/EDVR_M_w_tsa_SRx4.pdparams)
| EDVR_L_wo_tsa_SRx4 | REDS | [EDVR_L_wo_tsa_SRx4](https://paddlegan.bj.bcebos.com/models/EDVR_L_wo_tsa_SRx4.pdparams) | EDVR_L_wo_tsa_SRx4 | REDS | [EDVR_L_wo_tsa_SRx4](https://paddlegan.bj.bcebos.com/models/EDVR_L_wo_tsa_SRx4.pdparams)
| EDVR_L_w_tsa_SRx4 | REDS | [EDVR_L_w_tsa_SRx4](https://paddlegan.bj.bcebos.com/models/EDVR_L_w_tsa_SRx4.pdparams)
...@@ -81,4 +102,3 @@ The metrics are PSNR / SSIM. ...@@ -81,4 +102,3 @@ The metrics are PSNR / SSIM.
year = {2019} year = {2019}
} }
``` ```
...@@ -33,18 +33,33 @@ ...@@ -33,18 +33,33 @@
### 1.2.2 训练/测试 ### 1.2.2 训练/测试
使用处理后的REDS数据集训练与测试EDVR模型命令如下: EDVR模型根据模型中间通道数分为EDVR_L(128通道)和EDVR_M(64通道)两种模型。下面以EDVR_M模型为例介绍模型训练与测试。
EDVR模型训练一般分两个阶段训练,先不带TSA模块训练,训练与测试命令如下:
训练模型:
```
python -u tools/main.py --config-file configs/edvr_m_wo_tsa.yaml
```
测试模型:
```
python tools/main.py --config-file configs/edvr_m_wo_tsa.yaml --evaluate-only --load ${PATH_OF_WEIGHT_WITHOUT_TSA}
```
然后用保存的不带TSA模块的EDVR权重作为EDVR模型的初始化,训练完整的EDVR模型,训练与测试命令如下:
训练模型: 训练模型:
``` ```
python -u tools/main.py --config-file configs/edvr.yaml python -u tools/main.py --config-file configs/edvr_m_w_tsa.yaml --load ${PATH_OF_WEIGHT_WITHOUT_TSA}
``` ```
测试模型: 测试模型:
``` ```
python tools/main.py --config-file configs/edvr.yaml --evaluate-only --load ${PATH_OF_WEIGHT} python tools/main.py --config-file configs/edvr_m_w_tsa.yaml --evaluate-only --load ${PATH_OF_WEIGHT}
``` ```
## 1.3 实验结果展示 ## 1.3 实验结果展示
实验数值结果是在 RGB 通道上进行评估。 实验数值结果是在 RGB 通道上进行评估。
...@@ -55,6 +70,7 @@ ...@@ -55,6 +70,7 @@
| EDVR_M_wo_tsa_SRx4 | 30.4429 / 0.8684 | | EDVR_M_wo_tsa_SRx4 | 30.4429 / 0.8684 |
| EDVR_M_w_tsa_SRx4 | 30.5169 / 0.8699 | | EDVR_M_w_tsa_SRx4 | 30.5169 / 0.8699 |
| EDVR_L_wo_tsa_SRx4 | 30.8649 / 0.8761 | | EDVR_L_wo_tsa_SRx4 | 30.8649 / 0.8761 |
| EDVR_L_w_tsa_SRx4 | 30.9336 / 0.8773 |
## 1.4 模型下载 ## 1.4 模型下载
...@@ -63,6 +79,7 @@ ...@@ -63,6 +79,7 @@
| EDVR_M_wo_tsa_SRx4 | REDS | [EDVR_M_wo_tsa_SRx4](https://paddlegan.bj.bcebos.com/models/EDVR_M_wo_tsa_SRx4.pdparams) | EDVR_M_wo_tsa_SRx4 | REDS | [EDVR_M_wo_tsa_SRx4](https://paddlegan.bj.bcebos.com/models/EDVR_M_wo_tsa_SRx4.pdparams)
| EDVR_M_w_tsa_SRx4 | REDS | [EDVR_M_w_tsa_SRx4](https://paddlegan.bj.bcebos.com/models/EDVR_M_w_tsa_SRx4.pdparams) | EDVR_M_w_tsa_SRx4 | REDS | [EDVR_M_w_tsa_SRx4](https://paddlegan.bj.bcebos.com/models/EDVR_M_w_tsa_SRx4.pdparams)
| EDVR_L_wo_tsa_SRx4 | REDS | [EDVR_L_wo_tsa_SRx4](https://paddlegan.bj.bcebos.com/models/EDVR_L_wo_tsa_SRx4.pdparams) | EDVR_L_wo_tsa_SRx4 | REDS | [EDVR_L_wo_tsa_SRx4](https://paddlegan.bj.bcebos.com/models/EDVR_L_wo_tsa_SRx4.pdparams)
| EDVR_L_w_tsa_SRx4 | REDS | [EDVR_L_w_tsa_SRx4](https://paddlegan.bj.bcebos.com/models/EDVR_L_w_tsa_SRx4.pdparams)
...@@ -81,4 +98,3 @@ ...@@ -81,4 +98,3 @@
year = {2019} year = {2019}
} }
``` ```
...@@ -15,10 +15,8 @@ ...@@ -15,10 +15,8 @@
import paddle import paddle
import numpy as np import numpy as np
import scipy.io as scio
import paddle.nn as nn import paddle.nn as nn
from paddle.nn import initializer
from ...modules.init import kaiming_normal_, constant_ from ...modules.init import kaiming_normal_, constant_
from ...modules.dcn import DeformableConv_dygraph from ...modules.dcn import DeformableConv_dygraph
...@@ -63,11 +61,8 @@ class ResidualBlockNoBN(nn.Layer): ...@@ -63,11 +61,8 @@ class ResidualBlockNoBN(nn.Layer):
|________________| |________________|
Args: Args:
num_feat (int): Channel number of intermediate features. nf (int): Channel number of intermediate features.
Default: 64. Default: 64.
res_scale (float): Residual scale. Default: 1.
pytorch_init (bool): If set to True, use pytorch default init,
otherwise, use default_init_weights. Default: False.
""" """
def __init__(self, nf=64): def __init__(self, nf=64):
super(ResidualBlockNoBN, self).__init__() super(ResidualBlockNoBN, self).__init__()
...@@ -612,8 +607,7 @@ class EDVRNet(nn.Layer): ...@@ -612,8 +607,7 @@ class EDVRNet(nn.Layer):
center=None, center=None,
predeblur=False, predeblur=False,
HR_in=False, HR_in=False,
w_TSA=True, w_TSA=True):
TSA_only=False):
super(EDVRNet, self).__init__() super(EDVRNet, self).__init__()
self.in_nf = in_nf self.in_nf = in_nf
self.out_nf = out_nf self.out_nf = out_nf
...@@ -637,23 +631,6 @@ class EDVRNet(nn.Layer): ...@@ -637,23 +631,6 @@ class EDVRNet(nn.Layer):
out_channels=self.nf, out_channels=self.nf,
kernel_size=1, kernel_size=1,
stride=1) stride=1)
else:
if self.HR_in:
self.conv_first_1 = nn.Conv2D(in_channels=self.in_nf,
out_channels=self.nf,
kernel_size=3,
stride=1,
padding=1)
self.conv_first_2 = nn.Conv2D(in_channels=self.nf,
out_channels=self.nf,
kernel_size=3,
stride=2,
padding=1)
self.conv_first_3 = nn.Conv2D(in_channels=self.nf,
out_channels=self.nf,
kernel_size=3,
stride=2,
padding=1)
else: else:
self.conv_first = nn.Conv2D(in_channels=self.in_nf, self.conv_first = nn.Conv2D(in_channels=self.in_nf,
out_channels=self.nf, out_channels=self.nf,
...@@ -711,16 +688,16 @@ class EDVRNet(nn.Layer): ...@@ -711,16 +688,16 @@ class EDVRNet(nn.Layer):
padding=1) padding=1)
self.pixel_shuffle = nn.PixelShuffle(2) self.pixel_shuffle = nn.PixelShuffle(2)
self.upconv2 = nn.Conv2D(in_channels=self.nf, self.upconv2 = nn.Conv2D(in_channels=self.nf,
out_channels=4 * self.nf, out_channels=4 * 64,
kernel_size=3, kernel_size=3,
stride=1, stride=1,
padding=1) padding=1)
self.HRconv = nn.Conv2D(in_channels=self.nf, self.HRconv = nn.Conv2D(in_channels=64,
out_channels=self.nf, out_channels=64,
kernel_size=3, kernel_size=3,
stride=1, stride=1,
padding=1) padding=1)
self.conv_last = nn.Conv2D(in_channels=self.nf, self.conv_last = nn.Conv2D(in_channels=64,
out_channels=self.out_nf, out_channels=self.out_nf,
kernel_size=3, kernel_size=3,
stride=1, stride=1,
...@@ -746,16 +723,6 @@ class EDVRNet(nn.Layer): ...@@ -746,16 +723,6 @@ class EDVRNet(nn.Layer):
L1_fea = self.cov_1(L1_fea) L1_fea = self.cov_1(L1_fea)
if self.HR_in: if self.HR_in:
H, W = H // self.scale_factor, W // self.scale_factor H, W = H // self.scale_factor, W // self.scale_factor
else:
if self.HR_in:
L1_fea = self.conv_first_1(L1_fea)
L1_fea = self.Leaky_relu(L1_fea)
L1_fea = self.conv_first_2(L1_fea)
L1_fea = self.Leaky_relu(L1_fea)
L1_fea = self.conv_first_3(L1_fea)
L1_fea = self.Leaky_relu(L1_fea)
H = H // self.scale_factor
W = W // self.scale_factor
else: else:
L1_fea = self.conv_first(L1_fea) L1_fea = self.conv_first(L1_fea)
L1_fea = self.Leaky_relu(L1_fea) L1_fea = self.Leaky_relu(L1_fea)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册