未验证 提交 97f96b94 编写于 作者: L LielinJiang 提交者: GitHub

release edvr deblur config and pretrained model (#348)

* add edvr blur model
上级 19fe4fbc
total_iters: 600000
output_dir: output_dir
checkpoints_dir: checkpoints
find_unused_parameters: True
# tensor range for function tensor2img
min_max:
(0., 1.)
model:
name: EDVRModel
tsa_iter: 50000
generator:
name: EDVRNet
in_nf: 3
out_nf: 3
scale_factor: 1
nf: 128
nframes: 5
groups: 8
front_RBs: 5
back_RBs: 40
center: 2
predeblur: True #False
HR_in: True #False
w_TSA: True
pixel_criterion:
name: CharbonnierLoss
dataset:
train:
name: REDSDataset
mode: train
gt_folder: data/REDS/train_sharp/X4
lq_folder: data/REDS/train_blur/X4
img_format: png
crop_size: 256
interval_list: [1]
random_reverse: False
number_frames: 5
use_flip: True
use_rot: True
buf_size: 1024
scale: 1
fix_random_seed: 10
num_workers: 6
batch_size: 8
test:
name: REDSDataset
mode: test
gt_folder: data/REDS/REDS4_test_sharp/X4
lq_folder: data/REDS/REDS4_test_blur/X4
img_format: png
interval_list: [1]
random_reverse: False
number_frames: 5
batch_size: 1
use_flip: False
use_rot: False
buf_size: 1024
scale: 1
fix_random_seed: 10
lr_scheduler:
name: CosineAnnealingRestartLR
learning_rate: !!float 4e-4
periods: [50000, 100000, 150000, 150000, 150000]
restart_weights: [1, 0.5, 0.5, 0.5, 0.5]
eta_min: !!float 1e-7
optimizer:
name: Adam
# add parameters of net_name to optim
# name should in self.nets
net_names:
- generator
beta1: 0.9
beta2: 0.99
validate:
interval: 10000
save_img: false
metrics:
psnr: # metric name, can be arbitrary
name: PSNR
crop_border: 0
test_y_channel: False
ssim:
name: SSIM
crop_border: 0
test_y_channel: False
log_config:
interval: 50
visiual_interval: 5000
snapshot_config:
interval: 5000
total_iters: 600000
output_dir: output_dir
checkpoints_dir: checkpoints
# tensor range for function tensor2img
min_max:
(0., 1.)
model:
name: EDVRModel
tsa_iter: False
generator:
name: EDVRNet
in_nf: 3
out_nf: 3
scale_factor: 1
nf: 128
nframes: 5
groups: 8
front_RBs: 5
back_RBs: 40
center: 2
predeblur: True #False
HR_in: True #False
w_TSA: False
pixel_criterion:
name: CharbonnierLoss
dataset:
train:
name: REDSDataset
mode: train
gt_folder: data/REDS/train_sharp/X4
lq_folder: data/REDS/train_blur/X4
img_format: png
crop_size: 256
interval_list: [1]
random_reverse: False
number_frames: 5
use_flip: True
use_rot: True
buf_size: 1024
scale: 1
fix_random_seed: 10
num_workers: 6
batch_size: 8
test:
name: REDSDataset
mode: test
gt_folder: data/REDS/REDS4_test_sharp/X4
lq_folder: data/REDS/REDS4_test_blur/X4
img_format: png
interval_list: [1]
random_reverse: False
number_frames: 5
batch_size: 1
use_flip: False
use_rot: False
buf_size: 1024
scale: 1
fix_random_seed: 10
lr_scheduler:
name: CosineAnnealingRestartLR
learning_rate: !!float 2e-4
periods: [150000, 150000, 150000, 150000]
restart_weights: [1, 0.5, 0.5, 0.5]
eta_min: !!float 1e-7
optimizer:
name: Adam
# add parameters of net_name to optim
# name should in self.nets
net_names:
- generator
beta1: 0.9
beta2: 0.99
validate:
interval: 5000
save_img: false
metrics:
psnr: # metric name, can be arbitrary
name: PSNR
crop_border: 0
test_y_channel: False
ssim:
name: SSIM
crop_border: 0
test_y_channel: False
log_config:
interval: 10
visiual_interval: 5000
snapshot_config:
interval: 5000
...@@ -92,8 +92,8 @@ validate: ...@@ -92,8 +92,8 @@ validate:
test_y_channel: False test_y_channel: False
log_config: log_config:
interval: 10 interval: 100
visiual_interval: 500 visiual_interval: 5000
snapshot_config: snapshot_config:
interval: 5000 interval: 5000
...@@ -75,6 +75,8 @@ The metrics are PSNR / SSIM. ...@@ -75,6 +75,8 @@ The metrics are PSNR / SSIM.
| EDVR_M_w_tsa_SRx4 | 30.5169 / 0.8699 | | EDVR_M_w_tsa_SRx4 | 30.5169 / 0.8699 |
| EDVR_L_wo_tsa_SRx4 | 30.8649 / 0.8761 | | EDVR_L_wo_tsa_SRx4 | 30.8649 / 0.8761 |
| EDVR_L_w_tsa_SRx4 | 30.9336 / 0.8773 | | EDVR_L_w_tsa_SRx4 | 30.9336 / 0.8773 |
| EDVR_L_wo_tsa_deblur | 34.9587 / 0.9509 |
| EDVR_L_w_tsa_deblur | 35.1473 / 0.9526 |
## 1.4 Model Download ## 1.4 Model Download
...@@ -84,6 +86,8 @@ The metrics are PSNR / SSIM. ...@@ -84,6 +86,8 @@ The metrics are PSNR / SSIM.
| EDVR_M_w_tsa_SRx4 | REDS | [EDVR_M_w_tsa_SRx4](https://paddlegan.bj.bcebos.com/models/EDVR_M_w_tsa_SRx4.pdparams) | EDVR_M_w_tsa_SRx4 | REDS | [EDVR_M_w_tsa_SRx4](https://paddlegan.bj.bcebos.com/models/EDVR_M_w_tsa_SRx4.pdparams)
| EDVR_L_wo_tsa_SRx4 | REDS | [EDVR_L_wo_tsa_SRx4](https://paddlegan.bj.bcebos.com/models/EDVR_L_wo_tsa_SRx4.pdparams) | EDVR_L_wo_tsa_SRx4 | REDS | [EDVR_L_wo_tsa_SRx4](https://paddlegan.bj.bcebos.com/models/EDVR_L_wo_tsa_SRx4.pdparams)
| EDVR_L_w_tsa_SRx4 | REDS | [EDVR_L_w_tsa_SRx4](https://paddlegan.bj.bcebos.com/models/EDVR_L_w_tsa_SRx4.pdparams) | EDVR_L_w_tsa_SRx4 | REDS | [EDVR_L_w_tsa_SRx4](https://paddlegan.bj.bcebos.com/models/EDVR_L_w_tsa_SRx4.pdparams)
| EDVR_L_wo_tsa_deblur | REDS | [EDVR_L_wo_tsa_deblur](https://paddlegan.bj.bcebos.com/models/EDVR_L_wo_tsa_deblur.pdparams)
| EDVR_L_w_tsa_deblur | REDS | [EDVR_L_w_tsa_deblur](https://paddlegan.bj.bcebos.com/models/EDVR_L_w_tsa_deblur.pdparams)
......
...@@ -71,6 +71,8 @@ ...@@ -71,6 +71,8 @@
| EDVR_M_w_tsa_SRx4 | 30.5169 / 0.8699 | | EDVR_M_w_tsa_SRx4 | 30.5169 / 0.8699 |
| EDVR_L_wo_tsa_SRx4 | 30.8649 / 0.8761 | | EDVR_L_wo_tsa_SRx4 | 30.8649 / 0.8761 |
| EDVR_L_w_tsa_SRx4 | 30.9336 / 0.8773 | | EDVR_L_w_tsa_SRx4 | 30.9336 / 0.8773 |
| EDVR_L_wo_tsa_deblur | 34.9587 / 0.9509 |
| EDVR_L_w_tsa_deblur | 35.1473 / 0.9526 |
## 1.4 模型下载 ## 1.4 模型下载
...@@ -80,7 +82,8 @@ ...@@ -80,7 +82,8 @@
| EDVR_M_w_tsa_SRx4 | REDS | [EDVR_M_w_tsa_SRx4](https://paddlegan.bj.bcebos.com/models/EDVR_M_w_tsa_SRx4.pdparams) | EDVR_M_w_tsa_SRx4 | REDS | [EDVR_M_w_tsa_SRx4](https://paddlegan.bj.bcebos.com/models/EDVR_M_w_tsa_SRx4.pdparams)
| EDVR_L_wo_tsa_SRx4 | REDS | [EDVR_L_wo_tsa_SRx4](https://paddlegan.bj.bcebos.com/models/EDVR_L_wo_tsa_SRx4.pdparams) | EDVR_L_wo_tsa_SRx4 | REDS | [EDVR_L_wo_tsa_SRx4](https://paddlegan.bj.bcebos.com/models/EDVR_L_wo_tsa_SRx4.pdparams)
| EDVR_L_w_tsa_SRx4 | REDS | [EDVR_L_w_tsa_SRx4](https://paddlegan.bj.bcebos.com/models/EDVR_L_w_tsa_SRx4.pdparams) | EDVR_L_w_tsa_SRx4 | REDS | [EDVR_L_w_tsa_SRx4](https://paddlegan.bj.bcebos.com/models/EDVR_L_w_tsa_SRx4.pdparams)
| EDVR_L_wo_tsa_deblur | REDS | [EDVR_L_wo_tsa_deblur](https://paddlegan.bj.bcebos.com/models/EDVR_L_wo_tsa_deblur.pdparams)
| EDVR_L_w_tsa_deblur | REDS | [EDVR_L_w_tsa_deblur](https://paddlegan.bj.bcebos.com/models/EDVR_L_w_tsa_deblur.pdparams)
......
...@@ -17,7 +17,7 @@ import paddle.nn as nn ...@@ -17,7 +17,7 @@ import paddle.nn as nn
from .builder import MODELS from .builder import MODELS
from .sr_model import BaseSRModel from .sr_model import BaseSRModel
from .generators.edvr import ResidualBlockNoBN from .generators.edvr import ResidualBlockNoBN, DCNPack
from ..modules.init import reset_parameters from ..modules.init import reset_parameters
...@@ -77,10 +77,10 @@ class EDVRModel(BaseSRModel): ...@@ -77,10 +77,10 @@ class EDVRModel(BaseSRModel):
def init_edvr_weight(net): def init_edvr_weight(net):
def reset_func(m): def reset_func(m):
if hasattr(m, if hasattr(m, 'weight') and (not isinstance(
'weight') and (not isinstance(m, m, (nn.BatchNorm, nn.BatchNorm2D))) and (
(nn.BatchNorm, nn.BatchNorm2D)) not isinstance(m, ResidualBlockNoBN) and
) and (not isinstance(m, ResidualBlockNoBN)): (not isinstance(m, DCNPack))):
reset_parameters(m) reset_parameters(m)
net.apply(reset_func) net.apply(reset_func)
...@@ -17,7 +17,7 @@ import paddle ...@@ -17,7 +17,7 @@ import paddle
import numpy as np import numpy as np
import paddle.nn as nn import paddle.nn as nn
from ...modules.init import kaiming_normal_, constant_ from ...modules.init import kaiming_normal_, constant_, constant_init
from ...modules.dcn import DeformableConv_dygraph from ...modules.dcn import DeformableConv_dygraph
# from paddle.vision.ops import DeformConv2D #to be compiled # from paddle.vision.ops import DeformConv2D #to be compiled
...@@ -382,6 +382,8 @@ class DCNPack(nn.Layer): ...@@ -382,6 +382,8 @@ class DCNPack(nn.Layer):
deformable_groups=self.deformable_groups) deformable_groups=self.deformable_groups)
# self.dcn = DeformConv2D(in_channels=self.num_filters,out_channels=self.num_filters,kernel_size=self.kernel_size,stride=stride,padding=padding,dilation=dilation,deformable_groups=self.deformable_groups,groups=1) # to be compiled # self.dcn = DeformConv2D(in_channels=self.num_filters,out_channels=self.num_filters,kernel_size=self.kernel_size,stride=stride,padding=padding,dilation=dilation,deformable_groups=self.deformable_groups,groups=1) # to be compiled
self.sigmoid = nn.Sigmoid() self.sigmoid = nn.Sigmoid()
# init conv offset
constant_init(self.conv_offset_mask, 0., 0.)
def forward(self, fea_and_offset): def forward(self, fea_and_offset):
out = None out = None
...@@ -686,6 +688,7 @@ class EDVRNet(nn.Layer): ...@@ -686,6 +688,7 @@ class EDVRNet(nn.Layer):
kernel_size=3, kernel_size=3,
stride=1, stride=1,
padding=1) padding=1)
self.pixel_shuffle = nn.PixelShuffle(2) self.pixel_shuffle = nn.PixelShuffle(2)
self.upconv2 = nn.Conv2D(in_channels=self.nf, self.upconv2 = nn.Conv2D(in_channels=self.nf,
out_channels=4 * 64, out_channels=4 * 64,
...@@ -702,10 +705,11 @@ class EDVRNet(nn.Layer): ...@@ -702,10 +705,11 @@ class EDVRNet(nn.Layer):
kernel_size=3, kernel_size=3,
stride=1, stride=1,
padding=1) padding=1)
self.upsample = nn.Upsample(scale_factor=self.scale_factor, if self.scale_factor == 4:
mode="bilinear", self.upsample = nn.Upsample(scale_factor=self.scale_factor,
align_corners=False, mode="bilinear",
align_mode=0) align_corners=False,
align_mode=0)
def forward(self, x): def forward(self, x):
""" """
...@@ -722,7 +726,7 @@ class EDVRNet(nn.Layer): ...@@ -722,7 +726,7 @@ class EDVRNet(nn.Layer):
L1_fea = self.pre_deblur(L1_fea) L1_fea = self.pre_deblur(L1_fea)
L1_fea = self.cov_1(L1_fea) L1_fea = self.cov_1(L1_fea)
if self.HR_in: if self.HR_in:
H, W = H // self.scale_factor, W // self.scale_factor H, W = H // 4, W // 4
else: else:
L1_fea = self.conv_first(L1_fea) L1_fea = self.conv_first(L1_fea)
L1_fea = self.Leaky_relu(L1_fea) L1_fea = self.Leaky_relu(L1_fea)
...@@ -782,5 +786,6 @@ class EDVRNet(nn.Layer): ...@@ -782,5 +786,6 @@ class EDVRNet(nn.Layer):
base = x_center base = x_center
else: else:
base = self.upsample(x_center) base = self.upsample(x_center)
out += base out += base
return out return out
...@@ -233,7 +233,6 @@ class Encoder(nn.Layer): ...@@ -233,7 +233,6 @@ class Encoder(nn.Layer):
nn.Conv2D(512, 512, (3, 3)), nn.Conv2D(512, 512, (3, 3)),
nn.ReLU() # relu5-4 nn.ReLU() # relu5-4
) )
weight_path = get_path_from_url( weight_path = get_path_from_url(
'https://paddlegan.bj.bcebos.com/models/vgg_normalised.pdparams') 'https://paddlegan.bj.bcebos.com/models/vgg_normalised.pdparams')
vgg_net.set_dict(paddle.load(weight_path)) vgg_net.set_dict(paddle.load(weight_path))
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import math
import numpy as np import numpy as np
import paddle import paddle
from paddle.fluid.layer_helper import LayerHelper from paddle.fluid.layer_helper import LayerHelper
...@@ -22,10 +23,10 @@ from paddle.fluid.layers import nn, utils ...@@ -22,10 +23,10 @@ from paddle.fluid.layers import nn, utils
from paddle.nn import Layer from paddle.nn import Layer
from paddle.fluid.initializer import Normal from paddle.fluid.initializer import Normal
from paddle.common_ops_import import * from paddle.common_ops_import import *
from .init import uniform_, constant_
class DeformConv2D(Layer): class DeformConv2D(Layer):
def __init__(self, def __init__(self,
in_channels, in_channels,
out_channels, out_channels,
...@@ -69,8 +70,21 @@ class DeformConv2D(Layer): ...@@ -69,8 +70,21 @@ class DeformConv2D(Layer):
shape=filter_shape, shape=filter_shape,
attr=self._weight_attr, attr=self._weight_attr,
default_initializer=_get_default_param_initializer()) default_initializer=_get_default_param_initializer())
self.bias = self.create_parameter(
attr=self._bias_attr, shape=[self._out_channels], is_bias=True) self.bias = self.create_parameter(attr=self._bias_attr,
shape=[self._out_channels],
is_bias=True)
self.init_weight()
def init_weight(self):
n = self._in_channels
for k in self._kernel_size:
n *= k
stdv = 1. / math.sqrt(n)
uniform_(self.weight, -stdv, stdv)
if hasattr(self, 'bias') and self.bias is not None:
constant_(self.bias, 0.)
def forward(self, x, offset, mask): def forward(self, x, offset, mask):
out = deform_conv2d( out = deform_conv2d(
...@@ -84,7 +98,7 @@ class DeformConv2D(Layer): ...@@ -84,7 +98,7 @@ class DeformConv2D(Layer):
dilation=self._dilation, dilation=self._dilation,
deformable_groups=self._deformable_groups, deformable_groups=self._deformable_groups,
groups=self._groups, groups=self._groups,
) )
return out return out
...@@ -99,7 +113,7 @@ def deform_conv2d(x, ...@@ -99,7 +113,7 @@ def deform_conv2d(x,
deformable_groups=1, deformable_groups=1,
groups=1, groups=1,
name=None): name=None):
stride = utils.convert_to_list(stride, 2, 'stride') stride = utils.convert_to_list(stride, 2, 'stride')
padding = utils.convert_to_list(padding, 2, 'padding') padding = utils.convert_to_list(padding, 2, 'padding')
dilation = utils.convert_to_list(dilation, 2, 'dilation') dilation = utils.convert_to_list(dilation, 2, 'dilation')
...@@ -107,8 +121,9 @@ def deform_conv2d(x, ...@@ -107,8 +121,9 @@ def deform_conv2d(x,
use_deform_conv2d_v1 = True if mask is None else False use_deform_conv2d_v1 = True if mask is None else False
if in_dygraph_mode(): if in_dygraph_mode():
attrs = ('strides', stride, 'paddings', padding, 'dilations', dilation, 'deformable_groups',deformable_groups, attrs = ('strides', stride, 'paddings', padding, 'dilations', dilation,
'groups', groups, 'im2col_step', 1) 'deformable_groups', deformable_groups, 'groups', groups,
'im2col_step', 1)
if use_deform_conv2d_v1: if use_deform_conv2d_v1:
op_type = 'deformable_conv_v1' op_type = 'deformable_conv_v1'
pre_bias = getattr(core.ops, op_type)(x, offset, weight, *attrs) pre_bias = getattr(core.ops, op_type)(x, offset, weight, *attrs)
...@@ -144,8 +159,14 @@ def deform_conv2d(x, ...@@ -144,8 +159,14 @@ def deform_conv2d(x,
class DeformableConv_dygraph(Layer): class DeformableConv_dygraph(Layer):
def __init__(self,num_filters,filter_size,dilation, def __init__(self,
stride,padding,deformable_groups=1,groups=1): num_filters,
filter_size,
dilation,
stride,
padding,
deformable_groups=1,
groups=1):
super(DeformableConv_dygraph, self).__init__() super(DeformableConv_dygraph, self).__init__()
self.num_filters = num_filters self.num_filters = num_filters
self.filter_size = filter_size self.filter_size = filter_size
...@@ -154,12 +175,18 @@ class DeformableConv_dygraph(Layer): ...@@ -154,12 +175,18 @@ class DeformableConv_dygraph(Layer):
self.padding = padding self.padding = padding
self.deformable_groups = deformable_groups self.deformable_groups = deformable_groups
self.groups = groups self.groups = groups
self.defor_conv = DeformConv2D(in_channels=self.num_filters, out_channels=self.num_filters, self.defor_conv = DeformConv2D(in_channels=self.num_filters,
kernel_size=self.filter_size, stride=self.stride, padding=self.padding, out_channels=self.num_filters,
dilation=self.dilation, deformable_groups=self.deformable_groups, groups=self.groups, weight_attr=None, bias_attr=None) kernel_size=self.filter_size,
stride=self.stride,
padding=self.padding,
def forward(self,*input): dilation=self.dilation,
deformable_groups=self.deformable_groups,
groups=self.groups,
weight_attr=None,
bias_attr=None)
def forward(self, *input):
x = input[0] x = input[0]
offset = input[1] offset = input[1]
mask = input[2] mask = input[2]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册