Unverified commit 64181fa9, authored by wangna11BD, committed by GitHub

add edvr model (#208)

* add edvr model

* modifying code formats and comments

* modifying code formats and comments

* modifying code formats and comments

* add notes
Co-authored-by: LielinJiang <50691816+LielinJiang@users.noreply.github.com>
Parent e72aae53
total_iters: 600000
output_dir: output_dir
checkpoints_dir: checkpoints
# tensor range for function tensor2img
min_max: (0., 1.)

model:
  name: EDVRModel
  tsa_iter: 50000
  generator:
    name: EDVRNet
    in_nf: 3
    out_nf: 3
    scale_factor: 4
    nf: 64
    nframes: 5
    groups: 8
    front_RBs: 5
    back_RBs: 10
    center: 2
    predeblur: False
    HR_in: False
    w_TSA: True
    TSA_only: False
  pixel_criterion:
    name: CharbonnierLoss

dataset:
  train:
    name: REDSDataset
    mode: train
    gt_folder: data/REDS/train_sharp/X4
    lq_folder: data/REDS/train_sharp_bicubic/X4
    img_format: png
    crop_size: 256
    interval_list: [1]
    random_reverse: False
    number_frames: 5
    use_flip: True
    use_rot: True
    buf_size: 1024
    scale: 4
    fix_random_seed: 10
    num_workers: 3
    batch_size: 4
  test:
    name: REDSDataset
    mode: test
    gt_folder: data/REDS/REDS4_test_sharp/X4
    lq_folder: data/REDS/REDS4_test_sharp_bicubic/X4
    img_format: png
    interval_list: [1]
    random_reverse: False
    number_frames: 5
    batch_size: 1
    use_flip: False
    use_rot: False
    buf_size: 1024
    scale: 4
    fix_random_seed: 10

lr_scheduler:
  name: CosineAnnealingRestartLR
  learning_rate: !!float 4e-4
  periods: [50000, 100000, 150000, 150000, 150000]
  restart_weights: [1, 1, 1, 1, 1]
  eta_min: !!float 1e-7

optimizer:
  name: Adam
  # add parameters of net_name to optim
  # name should be in self.nets
  net_names:
    - generator
  beta1: 0.9
  beta2: 0.99

validate:
  interval: 5000
  save_img: false
  metrics:
    psnr: # metric name, can be arbitrary
      name: PSNR
      crop_border: 0
      test_y_channel: False
    ssim:
      name: SSIM
      crop_border: 0
      test_y_channel: False

log_config:
  interval: 10
  visiual_interval: 5000

snapshot_config:
  interval: 5000
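The lr_scheduler block above describes cosine annealing with warm restarts: each entry of periods is one cosine cycle whose peak is learning_rate scaled by the matching restart_weights entry. Below is a minimal sketch of that rule, assuming the usual cosine-restart formula rather than the repo's exact implementation; note that the first 50000-iteration cycle lines up with tsa_iter, so the TSA warm-up gets a full cycle of its own.

import math

def cosine_annealing_restart_lr(it, base_lr=4e-4,
                                periods=(50000, 100000, 150000, 150000, 150000),
                                restart_weights=(1, 1, 1, 1, 1),
                                eta_min=1e-7):
    # Find the restart cycle containing iteration `it`, then anneal from
    # base_lr * weight down to eta_min within that cycle.
    start = 0
    for period, weight in zip(periods, restart_weights):
        if it < start + period:
            t = (it - start) / period  # progress within the cycle, in [0, 1)
            return eta_min + weight * 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * t))
        start += period
    return eta_min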
total_iters: 600000
output_dir: output_dir
checkpoints_dir: checkpoints
# tensor range for function tensor2img
min_max: (0., 1.)

model:
  name: EDVRModel
  tsa_iter: 0
  generator:
    name: EDVRNet
    in_nf: 3
    out_nf: 3
    scale_factor: 4
    nf: 64
    nframes: 5
    groups: 8
    front_RBs: 5
    back_RBs: 10
    center: 2
    predeblur: False
    HR_in: False
    w_TSA: False
    TSA_only: False
  pixel_criterion:
    name: CharbonnierLoss

dataset:
  train:
    name: REDSDataset
    mode: train
    gt_folder: data/REDS/train_sharp/X4
    lq_folder: data/REDS/train_sharp_bicubic/X4
    img_format: png
    crop_size: 256
    interval_list: [1]
    random_reverse: False
    number_frames: 5
    use_flip: True
    use_rot: True
    buf_size: 1024
    scale: 4
    fix_random_seed: 10
    num_workers: 3
    batch_size: 4
  test:
    name: REDSDataset
    mode: test
    gt_folder: data/REDS/REDS4_test_sharp/X4
    lq_folder: data/REDS/REDS4_test_sharp_bicubic/X4
    img_format: png
    interval_list: [1]
    random_reverse: False
    number_frames: 5
    batch_size: 1
    use_flip: False
    use_rot: False
    buf_size: 1024
    scale: 4
    fix_random_seed: 10

lr_scheduler:
  name: CosineAnnealingRestartLR
  learning_rate: !!float 4e-4
  periods: [150000, 150000, 150000, 150000]
  restart_weights: [1, 1, 1, 1]
  eta_min: !!float 1e-7

optimizer:
  name: Adam
  # add parameters of net_name to optim
  # name should be in self.nets
  net_names:
    - generator
  beta1: 0.9
  beta2: 0.99

validate:
  interval: 5000
  save_img: false
  metrics:
    psnr: # metric name, can be arbitrary
      name: PSNR
      crop_border: 0
      test_y_channel: False
    ssim:
      name: SSIM
      crop_border: 0
      test_y_channel: False

log_config:
  interval: 10
  visiual_interval: 500

snapshot_config:
  interval: 5000
@@ -21,3 +21,4 @@ from .common_vision_dataset import CommonVisionDataset
from .animeganv2_dataset import AnimeGANV2Dataset
from .wav2lip_dataset import Wav2LipDataset
from .starganv2_dataset import StarGANv2Dataset
from .edvr_dataset import REDSDataset
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import random
import numpy as np
import scipy.io as scio
import cv2
import paddle
from paddle.io import Dataset, DataLoader
from .builder import DATASETS
logger = logging.getLogger(__name__)
@DATASETS.register()
class REDSDataset(Dataset):
"""
REDS dataset for EDVR model
"""
def __init__(self,
mode,
lq_folder,
gt_folder,
img_format="png",
crop_size=256,
interval_list=[1],
random_reverse=False,
number_frames=5,
batch_size=32,
use_flip=False,
use_rot=False,
buf_size=1024,
scale=4,
fix_random_seed=False):
super(REDSDataset, self).__init__()
self.format = img_format
self.mode = mode
self.crop_size = crop_size
self.interval_list = interval_list
self.random_reverse = random_reverse
self.number_frames = number_frames
self.batch_size = batch_size
self.fileroot = lq_folder
self.use_flip = use_flip
self.use_rot = use_rot
self.buf_size = buf_size
self.fix_random_seed = fix_random_seed
if self.mode != 'infer':
self.gtroot = gt_folder
self.scale = scale
self.LR_input = (self.scale > 1)
if self.fix_random_seed:
random.seed(10)
np.random.seed(10)
self.num_reader_threads = 1
self._init_()
def _init_(self):
logger.info('initialize reader ... ')
print("initialize reader")
self.filelist = []
for video_name in os.listdir(self.fileroot):
if (self.mode == 'train') and (video_name in [
'000', '011', '015', '020'
]): #These four videos are used as val
continue
for frame_name in os.listdir(os.path.join(self.fileroot,
video_name)):
frame_idx = frame_name.split('.')[0]
video_frame_idx = video_name + '_' + str(frame_idx)
# each item in self.filelist is like '010_00000015' or '260_00000090'
self.filelist.append(video_frame_idx)
if self.mode == 'test':
self.filelist.sort()
print(len(self.filelist))
def __getitem__(self, index):
"""Get training sample
return: lq:[5,3,W,H],
gt:[3,W,H],
lq_path:str
"""
item = self.filelist[index]
img_LQs, img_GT = self.get_sample_data(
item, self.number_frames, self.interval_list, self.random_reverse,
self.gtroot, self.fileroot, self.LR_input, self.crop_size,
self.scale, self.use_flip, self.use_rot, self.mode)
return {'lq': img_LQs, 'gt': img_GT, 'lq_path': self.filelist[index]}
def get_sample_data(self,
item,
number_frames,
interval_list,
random_reverse,
gtroot,
fileroot,
LR_input,
crop_size,
scale,
use_flip,
use_rot,
mode='train'):
video_name = item.split('_')[0]
frame_name = item.split('_')[1]
if (mode == 'train') or (mode == 'valid'):
ngb_frames, name_b = self.get_neighbor_frames(frame_name, \
number_frames=number_frames, \
interval_list=interval_list, \
random_reverse=random_reverse)
elif mode == 'test':
ngb_frames, name_b = self.get_test_neighbor_frames(
int(frame_name), number_frames)
else:
raise NotImplementedError('mode {} not implemented'.format(mode))
frame_name = name_b
img_GT = self.read_img(
os.path.join(gtroot, video_name, frame_name + '.png'))
frame_list = []
for ngb_frm in ngb_frames:
ngb_name = "%08d" % ngb_frm
img = self.read_img(
os.path.join(fileroot, video_name, ngb_name + '.png'))
frame_list.append(img)
H, W, C = frame_list[0].shape
# add random crop
if (mode == 'train') or (mode == 'valid'):
if LR_input:
LQ_size = crop_size // scale
rnd_h = random.randint(0, max(0, H - LQ_size))
rnd_w = random.randint(0, max(0, W - LQ_size))
frame_list = [
v[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :]
for v in frame_list
]
rnd_h_HR, rnd_w_HR = int(rnd_h * scale), int(rnd_w * scale)
img_GT = img_GT[rnd_h_HR:rnd_h_HR + crop_size,
rnd_w_HR:rnd_w_HR + crop_size, :]
else:
rnd_h = random.randint(0, max(0, H - crop_size))
rnd_w = random.randint(0, max(0, W - crop_size))
frame_list = [
v[rnd_h:rnd_h + crop_size, rnd_w:rnd_w + crop_size, :]
for v in frame_list
]
img_GT = img_GT[rnd_h:rnd_h + crop_size,
rnd_w:rnd_w + crop_size, :]
# add random flip and rotation
frame_list.append(img_GT)
if (mode == 'train') or (mode == 'valid'):
rlt = self.img_augment(frame_list, use_flip, use_rot)
else:
rlt = frame_list
frame_list = rlt[0:-1]
img_GT = rlt[-1]
# stack LQ images to NHWC, N is the frame number
img_LQs = np.stack(frame_list, axis=0)
# BGR to RGB, HWC to CHW, numpy to tensor
img_GT = img_GT[:, :, [2, 1, 0]]
img_LQs = img_LQs[:, :, :, [2, 1, 0]]
img_GT = np.transpose(img_GT, (2, 0, 1)).astype('float32')
img_LQs = np.transpose(img_LQs, (0, 3, 1, 2)).astype('float32')
return img_LQs, img_GT
def get_neighbor_frames(self,
frame_name,
number_frames,
interval_list,
random_reverse,
max_frame=99,
bordermode=False):
center_frame_idx = int(frame_name)
half_N_frames = number_frames // 2
interval = random.choice(interval_list)
if bordermode:
direction = 1
if random_reverse and random.random() < 0.5:
direction = random.choice([0, 1])
if center_frame_idx + interval * (number_frames - 1) > max_frame:
direction = 0
elif center_frame_idx - interval * (number_frames - 1) < 0:
direction = 1
if direction == 1:
neighbor_list = list(
range(center_frame_idx,
center_frame_idx + interval * number_frames,
interval))
else:
neighbor_list = list(
range(center_frame_idx,
center_frame_idx - interval * number_frames,
-interval))
name_b = '{:08d}'.format(neighbor_list[0])
else:
# ensure not exceeding the borders
while (center_frame_idx + half_N_frames * interval > max_frame) or (
center_frame_idx - half_N_frames * interval < 0):
center_frame_idx = random.randint(0, max_frame)
neighbor_list = list(
range(center_frame_idx - half_N_frames * interval,
center_frame_idx + half_N_frames * interval + 1,
interval))
if random_reverse and random.random() < 0.5:
neighbor_list.reverse()
name_b = '{:08d}'.format(neighbor_list[half_N_frames])
assert len(neighbor_list) == number_frames, \
"frames selected have length ({}), but it should be ({})".format(len(neighbor_list), number_frames)
return neighbor_list, name_b
def read_img(self, path, size=None):
"""read image by cv2
return: Numpy float32, HWC, BGR, [0,1]
"""
img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
img = img.astype(np.float32) / 255.
if img.ndim == 2:
img = np.expand_dims(img, axis=2)
# some images have 4 channels
if img.shape[2] > 3:
img = img[:, :, :3]
return img
def img_augment(self, img_list, hflip=True, rot=True):
"""horizontal flip OR rotate (0, 90, 180, 270 degrees)
"""
hflip = hflip and random.random() < 0.5
vflip = rot and random.random() < 0.5
rot90 = rot and random.random() < 0.5
def _augment(img):
if hflip:
img = img[:, ::-1, :]
if vflip:
img = img[::-1, :, :]
if rot90:
img = img.transpose(1, 0, 2)
return img
return [_augment(img) for img in img_list]
def get_test_neighbor_frames(self, crt_i, N, max_n=100, padding='new_info'):
"""Generate an index list for reading N frames from a sequence of images
Args:
crt_i (int): current center index
N (int): number of frames to read
max_n (int): total number of frames in the sequence (counted from 1)
padding (str): padding mode, one of replicate | reflection | new_info | circle
Example: crt_i = 0, N = 5
replicate: [0, 0, 0, 1, 2]
reflection: [2, 1, 0, 1, 2]
new_info: [4, 3, 0, 1, 2]
circle: [3, 4, 0, 1, 2]
Returns:
return_l (list [int]): a list of indexes
"""
max_n = max_n - 1
n_pad = N // 2
return_l = []
for i in range(crt_i - n_pad, crt_i + n_pad + 1):
if i < 0:
if padding == 'replicate':
add_idx = 0
elif padding == 'reflection':
add_idx = -i
elif padding == 'new_info':
add_idx = (crt_i + n_pad) + (-i)
elif padding == 'circle':
add_idx = N + i
else:
raise ValueError('Wrong padding mode')
elif i > max_n:
if padding == 'replicate':
add_idx = max_n
elif padding == 'reflection':
add_idx = max_n * 2 - i
elif padding == 'new_info':
add_idx = (crt_i - n_pad) - (i - max_n)
elif padding == 'circle':
add_idx = i - N
else:
raise ValueError('Wrong padding mode')
else:
add_idx = i
return_l.append(add_idx)
print(return_l)
name_b = '{:08d}'.format(crt_i)
return return_l, name_b
def __len__(self):
"""Return the total number of images in the dataset.
"""
return len(self.filelist)
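A hedged usage sketch for the dataset class above (it assumes the REDS data has been prepared at the paths from the train config; shapes follow __getitem__ and the crop_size // scale LQ crop):

dataset = REDSDataset(mode='train',
                      lq_folder='data/REDS/train_sharp_bicubic/X4',
                      gt_folder='data/REDS/train_sharp/X4',
                      crop_size=256,
                      number_frames=5,
                      use_flip=True,
                      use_rot=True,
                      scale=4)
sample = dataset[0]
# lq: (5, 3, 64, 64) low-resolution frames, gt: (3, 256, 256) ground truth
print(sample['lq'].shape, sample['gt'].shape)
# For training this would normally be wrapped in a paddle.io.DataLoader.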
@@ -27,3 +27,4 @@ from .styleganv2_model import StyleGAN2Model
from .wav2lip_model import Wav2LipModel
from .wav2lip_hq_model import Wav2LipModelHq
from .starganv2_model import StarGANv2Model
from .edvr_model import EDVRModel
from .gan_loss import GANLoss
from .perceptual_loss import PerceptualLoss
from .pixel_loss import L1Loss, MSELoss
from .pixel_loss import L1Loss, MSELoss, CharbonnierLoss
from .builder import build_criterion
@@ -49,6 +49,27 @@ class L1Loss():
return self.loss_weight * self._l1_loss(pred, target)
@CRITERIONS.register()
class CharbonnierLoss():
"""Charbonnier Loss (L1).
Args:
eps (float): Default: 1e-12.
"""
def __init__(self, eps=1e-12):
self.eps = eps
def __call__(self, pred, target, **kwargs):
"""Forward Function.
Args:
pred (Tensor): of shape (N, C, H, W). Predicted tensor.
target (Tensor): of shape (N, C, H, W). Ground truth tensor.
"""
return paddle.sum(paddle.sqrt((pred - target)**2 + self.eps))
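A quick numerical sanity check for the loss above (a sketch): with eps tiny, each element contributes roughly |pred - target|, and the result is summed rather than averaged.

import paddle

loss_fn = CharbonnierLoss()
pred = paddle.zeros([1, 3, 4, 4])
target = paddle.ones([1, 3, 4, 4])
# 48 elements, each differing by 1.0, so the summed loss is ~48.0
print(float(loss_fn(pred, target)))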
@CRITERIONS.register()
class MSELoss():
"""MSE (L2) loss.
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
from .builder import MODELS
from .sr_model import BaseSRModel
from .generators.edvr import ResidualBlockNoBN
from ..modules.init import reset_parameters
@MODELS.register()
class EDVRModel(BaseSRModel):
"""EDVR Model.
Paper: EDVR: Video Restoration with Enhanced Deformable Convolutional Networks.
"""
def __init__(self, generator, tsa_iter, pixel_criterion=None):
"""Initialize the EDVR class.
Args:
generator (dict): config of generator.
tsa_iter (int): number of iterations during which only the TSA module is trained.
pixel_criterion (dict): config of pixel criterion.
"""
super(EDVRModel, self).__init__(generator, pixel_criterion)
self.tsa_iter = tsa_iter
self.current_iter = 1
init_edvr_weight(self.nets['generator'])
def setup_input(self, input):
self.lq = paddle.to_tensor(input['lq'])
self.visual_items['lq'] = self.lq[:, 2, :, :, :]
self.visual_items['lq-2'] = self.lq[:, 0, :, :, :]
self.visual_items['lq-1'] = self.lq[:, 1, :, :, :]
self.visual_items['lq+1'] = self.lq[:, 3, :, :, :]
self.visual_items['lq+2'] = self.lq[:, 4, :, :, :]
if 'gt' in input:
self.gt = paddle.to_tensor(input['gt'])
self.visual_items['gt'] = self.gt
self.image_paths = input['lq_path']
def train_iter(self, optims=None):
optims['optim'].clear_grad()
if self.tsa_iter:
if self.current_iter == 1:
print('Only train TSA module for', self.tsa_iter, 'iters.')
for name, param in self.nets['generator'].named_parameters():
if 'TSAModule' not in name:
param.trainable = False
elif self.current_iter == self.tsa_iter + 1:
print('Train all the parameters.')
for param in self.nets['generator'].parameters():
param.trainable = True
self.output = self.nets['generator'](self.lq)
self.visual_items['output'] = self.output
# pixel loss
loss_pixel = self.pixel_criterion(self.output, self.gt)
self.losses['loss_pixel'] = loss_pixel
loss_pixel.backward()
optims['optim'].step()
self.current_iter += 1
def init_edvr_weight(net):
def reset_func(m):
if hasattr(m,
'weight') and (not isinstance(m,
(nn.BatchNorm, nn.BatchNorm2D))
) and (not isinstance(m, ResidualBlockNoBN)):
reset_parameters(m)
net.apply(reset_func)
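The branch in train_iter above freezes everything except the TSA fusion module for the first tsa_iter steps, then re-enables all parameters. Here is a hypothetical standalone helper with the same gating (set_tsa_only_trainable is illustrative, not part of this commit):

def set_tsa_only_trainable(generator, current_iter, tsa_iter):
    # During the warm-up only parameters whose names contain 'TSAModule'
    # receive gradients; afterwards every parameter is trainable again.
    if current_iter == 1:
        for name, param in generator.named_parameters():
            param.trainable = 'TSAModule' in name
    elif current_iter == tsa_iter + 1:
        for param in generator.parameters():
            param.trainable = True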
@@ -27,4 +27,5 @@ from .generator_styleganv2 import StyleGANv2Generator
from .generator_pixel2style2pixel import Pixel2Style2Pixel
from .drn import DRNGenerator
from .generator_starganv2 import StarGANv2Generator, StarGANv2Style, StarGANv2Mapping, FAN
from .edvr import EDVRNet
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import numpy as np
import scipy.io as scio
import paddle.nn as nn
from paddle.nn import initializer
from ...modules.init import kaiming_normal_, constant_
from ...modules.dcn import DeformableConv_dygraph
# from paddle.vision.ops import DeformConv2D #to be compiled
from .builder import GENERATORS
@paddle.no_grad()
def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
"""Initialize network weights.
Args:
module_list (list[nn.Module] | nn.Module): Modules to be initialized.
scale (float): Scale initialized weights, especially for residual
blocks. Default: 1.
bias_fill (float): The value to fill bias. Default: 0
kwargs (dict): Other arguments for initialization function.
"""
if not isinstance(module_list, list):
module_list = [module_list]
for m in module_list:
if isinstance(m, nn.Conv2D):
kaiming_normal_(m.weight, **kwargs)
scale_weight = scale * m.weight
m.weight.set_value(scale_weight)
if m.bias is not None:
constant_(m.bias, bias_fill)
elif isinstance(m, nn.Linear):
kaiming_normal_(m.weight, **kwargs)
scale_weight = scale * m.weight
m.weight.set_value(scale_weight)
if m.bias is not None:
constant_(m.bias, bias_fill)
class ResidualBlockNoBN(nn.Layer):
"""Residual block without BN.
It has a style of:
---Conv-ReLU-Conv-+-
|________________|
Args:
nf (int): Channel number of intermediate features. Default: 64.
"""
def __init__(self, nf=64):
super(ResidualBlockNoBN, self).__init__()
self.nf = nf
self.conv1 = nn.Conv2D(self.nf, self.nf, 3, 1, 1)
self.conv2 = nn.Conv2D(self.nf, self.nf, 3, 1, 1)
self.relu = nn.ReLU()
default_init_weights([self.conv1, self.conv2], 0.1)
def forward(self, x):
identity = x
out = self.conv2(self.relu(self.conv1(x)))
return identity + out
def MakeMultiBlocks(func, num_layers, nf=64):
"""Make layers by stacking the same blocks.
Args:
func (nn.Layer): nn.Layer class for basic block.
num_layers (int): number of blocks.
Returns:
nn.Sequential: Stacked blocks in nn.Sequential.
"""
Blocks = nn.Sequential()
for i in range(num_layers):
Blocks.add_sublayer('block%d' % i, func(nf))
return Blocks
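A quick shape check for the helper above (a sketch): ResidualBlockNoBN preserves both channel count and spatial size, so the stacked output matches the input.

import paddle

blocks = MakeMultiBlocks(ResidualBlockNoBN, num_layers=5, nf=64)
x = paddle.randn([1, 64, 32, 32])
print(blocks(x).shape)  # [1, 64, 32, 32]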
class PredeblurResNetPyramid(nn.Layer):
"""Pre-dublur module.
Args:
in_nf (int): Channel number of input image. Default: 3.
nf (int): Channel number of intermediate features. Default: 64.
HR_in (bool): Whether the input is high resolution. Default: False.
"""
def __init__(self, in_nf=3, nf=64, HR_in=False):
super(PredeblurResNetPyramid, self).__init__()
self.in_nf = in_nf
self.nf = nf
self.HR_in = True if HR_in else False
self.Leaky_relu = nn.LeakyReLU(negative_slope=0.1)
if self.HR_in:
self.conv_first_1 = nn.Conv2D(in_channels=self.in_nf,
out_channels=self.nf,
kernel_size=3,
stride=1,
padding=1)
self.conv_first_2 = nn.Conv2D(in_channels=self.nf,
out_channels=self.nf,
kernel_size=3,
stride=2,
padding=1)
self.conv_first_3 = nn.Conv2D(in_channels=self.nf,
out_channels=self.nf,
kernel_size=3,
stride=2,
padding=1)
else:
self.conv_first = nn.Conv2D(in_channels=self.in_nf,
out_channels=self.nf,
kernel_size=3,
stride=1,
padding=1)
self.RB_L1_1 = ResidualBlockNoBN(nf=self.nf)
self.RB_L1_2 = ResidualBlockNoBN(nf=self.nf)
self.RB_L1_3 = ResidualBlockNoBN(nf=self.nf)
self.RB_L1_4 = ResidualBlockNoBN(nf=self.nf)
self.RB_L1_5 = ResidualBlockNoBN(nf=self.nf)
self.RB_L2_1 = ResidualBlockNoBN(nf=self.nf)
self.RB_L2_2 = ResidualBlockNoBN(nf=self.nf)
self.RB_L3_1 = ResidualBlockNoBN(nf=self.nf)
self.deblur_L2_conv = nn.Conv2D(in_channels=self.nf,
out_channels=self.nf,
kernel_size=3,
stride=2,
padding=1)
self.deblur_L3_conv = nn.Conv2D(in_channels=self.nf,
out_channels=self.nf,
kernel_size=3,
stride=2,
padding=1)
self.upsample = nn.Upsample(scale_factor=2,
mode="bilinear",
align_corners=False,
align_mode=0)
def forward(self, x):
if self.HR_in:
L1_fea = self.Leaky_relu(self.conv_first_1(x))
L1_fea = self.Leaky_relu(self.conv_first_2(L1_fea))
L1_fea = self.Leaky_relu(self.conv_first_3(L1_fea))
else:
L1_fea = self.Leaky_relu(self.conv_first(x))
L2_fea = self.deblur_L2_conv(L1_fea)
L2_fea = self.Leaky_relu(L2_fea)
L3_fea = self.deblur_L3_conv(L2_fea)
L3_fea = self.Leaky_relu(L3_fea)
L3_fea = self.RB_L3_1(L3_fea)
L3_fea = self.upsample(L3_fea)
L2_fea = self.RB_L2_1(L2_fea) + L3_fea
L2_fea = self.RB_L2_2(L2_fea)
L2_fea = self.upsample(L2_fea)
L1_fea = self.RB_L1_1(L1_fea)
L1_fea = self.RB_L1_2(L1_fea) + L2_fea
out = self.RB_L1_3(L1_fea)
out = self.RB_L1_4(out)
out = self.RB_L1_5(out)
return out
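A shape sketch for the module above: with HR_in=False the pyramid downsamples and upsamples internally, so the output keeps the input resolution (with HR_in=True the two stride-2 first convs reduce it 4x).

import paddle

net = PredeblurResNetPyramid(in_nf=3, nf=64, HR_in=False)
x = paddle.randn([1, 3, 64, 64])
print(net(x).shape)  # [1, 64, 64, 64]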
class TSAFusion(nn.Layer):
"""Temporal Spatial Attention (TSA) fusion module.
Temporal: Calculate the correlation between center frame and
neighboring frames;
Spatial: It has 3 pyramid levels, the attention is similar to SFT.
(SFT: Recovering realistic texture in image super-resolution by deep
spatial feature transform.)
Args:
nf (int): Channel number of middle features. Default: 64.
nframes (int): Number of frames. Default: 5.
center (int): The index of center frame. Default: 2.
"""
def __init__(self, nf=64, nframes=5, center=2):
super(TSAFusion, self).__init__()
self.nf = nf
self.nframes = nframes
self.center = center
self.sigmoid = nn.Sigmoid()
self.Leaky_relu = nn.LeakyReLU(negative_slope=0.1)
self.tAtt_2 = nn.Conv2D(in_channels=self.nf,
out_channels=self.nf,
kernel_size=3,
stride=1,
padding=1)
self.tAtt_1 = nn.Conv2D(in_channels=self.nf,
out_channels=self.nf,
kernel_size=3,
stride=1,
padding=1)
self.fea_fusion = nn.Conv2D(in_channels=self.nf * self.nframes,
out_channels=self.nf,
kernel_size=1,
stride=1,
padding=0)
self.sAtt_1 = nn.Conv2D(in_channels=self.nf * self.nframes,
out_channels=self.nf,
kernel_size=1,
stride=1,
padding=0)
self.max_pool = nn.MaxPool2D(3, stride=2, padding=1)
self.avg_pool = nn.AvgPool2D(3, stride=2, padding=1, exclusive=False)
self.sAtt_2 = nn.Conv2D(in_channels=2 * self.nf,
out_channels=self.nf,
kernel_size=1,
stride=1,
padding=0)
self.sAtt_3 = nn.Conv2D(in_channels=self.nf,
out_channels=self.nf,
kernel_size=3,
stride=1,
padding=1)
self.sAtt_4 = nn.Conv2D(
in_channels=self.nf,
out_channels=self.nf,
kernel_size=1,
stride=1,
padding=0,
)
self.sAtt_5 = nn.Conv2D(in_channels=self.nf,
out_channels=self.nf,
kernel_size=3,
stride=1,
padding=1)
self.sAtt_add_1 = nn.Conv2D(in_channels=self.nf,
out_channels=self.nf,
kernel_size=1,
stride=1,
padding=0)
self.sAtt_add_2 = nn.Conv2D(in_channels=self.nf,
out_channels=self.nf,
kernel_size=1,
stride=1,
padding=0)
self.sAtt_L1 = nn.Conv2D(in_channels=self.nf,
out_channels=self.nf,
kernel_size=1,
stride=1,
padding=0)
self.sAtt_L2 = nn.Conv2D(
in_channels=2 * self.nf,
out_channels=self.nf,
kernel_size=3,
stride=1,
padding=1,
)
self.sAtt_L3 = nn.Conv2D(in_channels=self.nf,
out_channels=self.nf,
kernel_size=3,
stride=1,
padding=1)
self.upsample = nn.Upsample(scale_factor=2,
mode="bilinear",
align_corners=False,
align_mode=0)
def forward(self, aligned_fea):
"""
Args:
aligned_fea (Tensor): Aligned features with shape (b, n, c, h, w).
Returns:
Tensor: Features after TSA with the shape (b, c, h, w).
"""
B, N, C, H, W = aligned_fea.shape
x_center = aligned_fea[:, self.center, :, :, :]
emb_rf = self.tAtt_2(x_center)
emb = aligned_fea.reshape([-1, C, H, W])
emb = self.tAtt_1(emb)
emb = emb.reshape([-1, N, self.nf, H, W])
cor_l = []
for i in range(N):
emb_nbr = emb[:, i, :, :, :]  #[B,C,H,W]
cor_tmp = paddle.sum(emb_nbr * emb_rf, axis=1)
cor_tmp = paddle.unsqueeze(cor_tmp, axis=1)
cor_l.append(cor_tmp)
cor_prob = paddle.concat(cor_l, axis=1) #[B,N,H,W]
cor_prob = self.sigmoid(cor_prob)
cor_prob = paddle.unsqueeze(cor_prob, axis=2) #[B,N,1,H,W]
cor_prob = paddle.expand(cor_prob, [B, N, self.nf, H, W]) #[B,N,C,H,W]
cor_prob = cor_prob.reshape([B, -1, H, W])
aligned_fea = aligned_fea.reshape([B, -1, H, W])
aligned_fea = aligned_fea * cor_prob
fea = self.fea_fusion(aligned_fea)
fea = self.Leaky_relu(fea)
#spatial fusion
att = self.sAtt_1(aligned_fea)
att = self.Leaky_relu(att)
att_max = self.max_pool(att)
att_avg = self.avg_pool(att)
att_pool = paddle.concat([att_max, att_avg], axis=1)
att = self.sAtt_2(att_pool)
att = self.Leaky_relu(att)
#pyramid
att_L = self.sAtt_L1(att)
att_L = self.Leaky_relu(att_L)
att_max = self.max_pool(att_L)
att_avg = self.avg_pool(att_L)
att_pool = paddle.concat([att_max, att_avg], axis=1)
att_L = self.sAtt_L2(att_pool)
att_L = self.Leaky_relu(att_L)
att_L = self.sAtt_L3(att_L)
att_L = self.Leaky_relu(att_L)
att_L = self.upsample(att_L)
att = self.sAtt_3(att)
att = self.Leaky_relu(att)
att = att + att_L
att = self.sAtt_4(att)
att = self.Leaky_relu(att)
att = self.upsample(att)
att = self.sAtt_5(att)
att_add = self.sAtt_add_1(att)
att_add = self.Leaky_relu(att_add)
att_add = self.sAtt_add_2(att_add)
att = self.sigmoid(att)
fea = fea * att * 2 + att_add
return fea
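A shape sketch matching the forward docstring above: five aligned 64-channel frames are fused into a single 64-channel feature map.

import paddle

tsa = TSAFusion(nf=64, nframes=5, center=2)
aligned = paddle.randn([2, 5, 64, 32, 32])  # (b, n, c, h, w)
print(tsa(aligned).shape)  # [2, 64, 32, 32]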
class DCNPack(nn.Layer):
"""Modulated deformable conv for deformable alignment.
Ref:
Delving Deep into Deformable Alignment in Video Super-Resolution.
"""
def __init__(self,
num_filters=64,
kernel_size=3,
stride=1,
padding=1,
dilation=1,
deformable_groups=8,
extra_offset_mask=True):
super(DCNPack, self).__init__()
self.extra_offset_mask = extra_offset_mask
self.deformable_groups = deformable_groups
self.num_filters = num_filters
if isinstance(kernel_size, int):
self.kernel_size = [kernel_size, kernel_size]
self.conv_offset_mask = nn.Conv2D(in_channels=self.num_filters,
out_channels=self.deformable_groups *
3 * self.kernel_size[0] *
self.kernel_size[1],
kernel_size=self.kernel_size,
stride=stride,
padding=padding)
self.total_channels = self.deformable_groups * 3 * self.kernel_size[
0] * self.kernel_size[1]
self.split_channels = self.total_channels // 3
self.dcn = DeformableConv_dygraph(
num_filters=self.num_filters,
filter_size=self.kernel_size,
dilation=dilation,
stride=stride,
padding=padding,
deformable_groups=self.deformable_groups)
# self.dcn = DeformConv2D(in_channels=self.num_filters,out_channels=self.num_filters,kernel_size=self.kernel_size,stride=stride,padding=padding,dilation=dilation,deformable_groups=self.deformable_groups,groups=1) # to be compiled
self.sigmoid = nn.Sigmoid()
def forward(self, fea_and_offset):
out = None
x = None
if self.extra_offset_mask:
out = self.conv_offset_mask(fea_and_offset[1])
x = fea_and_offset[0]
o1 = out[:, 0:self.split_channels, :, :]
o2 = out[:, self.split_channels:2 * self.split_channels, :, :]
mask = out[:, 2 * self.split_channels:, :, :]
offset = paddle.concat([o1, o2], axis=1)
mask = self.sigmoid(mask)
y = self.dcn(x, offset, mask)
return y
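A usage sketch for DCNPack (it assumes a Paddle build where the deformable-conv op behind DeformableConv_dygraph is available). With extra_offset_mask=True, forward takes [features, offset_source]; the offset source is convolved into deformable_groups * 3 * k * k channels and split into x/y offsets and a sigmoid mask.

import paddle

dcn = DCNPack(num_filters=64, kernel_size=3, deformable_groups=8)
fea = paddle.randn([1, 64, 16, 16])
print(dcn([fea, fea]).shape)  # [1, 64, 16, 16]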
class PCDAlign(nn.Layer):
"""Alignment module using Pyramid, Cascading and Deformable convolution
(PCD). It is used in EDVR.
Ref:
EDVR: Video Restoration with Enhanced Deformable Convolutional Networks
Args:
nf (int): Channel number of middle features. Default: 64.
groups (int): Deformable groups. Default: 8.
"""
def __init__(self, nf=64, groups=8):
super(PCDAlign, self).__init__()
self.nf = nf
self.groups = groups
self.Leaky_relu = nn.LeakyReLU(negative_slope=0.1)
self.upsample = nn.Upsample(scale_factor=2,
mode="bilinear",
align_corners=False,
align_mode=0)
# Pyramid has three levels:
# L3: level 3, 1/4 spatial size
# L2: level 2, 1/2 spatial size
# L1: level 1, original spatial size
# L3
self.PCD_Align_L3_offset_conv1 = nn.Conv2D(in_channels=nf * 2,
out_channels=nf,
kernel_size=3,
stride=1,
padding=1)
self.PCD_Align_L3_offset_conv2 = nn.Conv2D(in_channels=nf,
out_channels=nf,
kernel_size=3,
stride=1,
padding=1)
self.PCD_Align_L3_dcn = DCNPack(num_filters=nf,
kernel_size=3,
stride=1,
padding=1,
deformable_groups=groups)
#L2
self.PCD_Align_L2_offset_conv1 = nn.Conv2D(in_channels=nf * 2,
out_channels=nf,
kernel_size=3,
stride=1,
padding=1)
self.PCD_Align_L2_offset_conv2 = nn.Conv2D(in_channels=nf * 2,
out_channels=nf,
kernel_size=3,
stride=1,
padding=1)
self.PCD_Align_L2_offset_conv3 = nn.Conv2D(in_channels=nf,
out_channels=nf,
kernel_size=3,
stride=1,
padding=1)
self.PCD_Align_L2_dcn = DCNPack(num_filters=nf,
kernel_size=3,
stride=1,
padding=1,
deformable_groups=groups)
self.PCD_Align_L2_fea_conv = nn.Conv2D(in_channels=nf * 2,
out_channels=nf,
kernel_size=3,
stride=1,
padding=1)
#L1
self.PCD_Align_L1_offset_conv1 = nn.Conv2D(in_channels=nf * 2,
out_channels=nf,
kernel_size=3,
stride=1,
padding=1)
self.PCD_Align_L1_offset_conv2 = nn.Conv2D(in_channels=nf * 2,
out_channels=nf,
kernel_size=3,
stride=1,
padding=1)
self.PCD_Align_L1_offset_conv3 = nn.Conv2D(in_channels=nf,
out_channels=nf,
kernel_size=3,
stride=1,
padding=1)
self.PCD_Align_L1_dcn = DCNPack(num_filters=nf,
kernel_size=3,
stride=1,
padding=1,
deformable_groups=groups)
self.PCD_Align_L1_fea_conv = nn.Conv2D(in_channels=nf * 2,
out_channels=nf,
kernel_size=3,
stride=1,
padding=1)
#cascade
self.PCD_Align_cas_offset_conv1 = nn.Conv2D(in_channels=nf * 2,
out_channels=nf,
kernel_size=3,
stride=1,
padding=1)
self.PCD_Align_cas_offset_conv2 = nn.Conv2D(in_channels=nf,
out_channels=nf,
kernel_size=3,
stride=1,
padding=1)
self.PCD_Align_cascade_dcn = DCNPack(num_filters=nf,
kernel_size=3,
stride=1,
padding=1,
deformable_groups=groups)
def forward(self, nbr_fea_l, ref_fea_l):
"""Align neighboring frame features to the reference frame features.
Args:
nbr_fea_l (list[Tensor]): Neighboring feature list. It
contains three pyramid levels (L1, L2, L3),
each with shape (b, c, h, w).
ref_fea_l (list[Tensor]): Reference feature list. It
contains three pyramid levels (L1, L2, L3),
each with shape (b, c, h, w).
Returns:
Tensor: Aligned features.
"""
#L3
L3_offset = paddle.concat([nbr_fea_l[2], ref_fea_l[2]], axis=1)
L3_offset = self.PCD_Align_L3_offset_conv1(L3_offset)
L3_offset = self.Leaky_relu(L3_offset)
L3_offset = self.PCD_Align_L3_offset_conv2(L3_offset)
L3_offset = self.Leaky_relu(L3_offset)
L3_fea = self.PCD_Align_L3_dcn([nbr_fea_l[2], L3_offset])
L3_fea = self.Leaky_relu(L3_fea)
#L2
L2_offset = paddle.concat([nbr_fea_l[1], ref_fea_l[1]], axis=1)
L2_offset = self.PCD_Align_L2_offset_conv1(L2_offset)
L2_offset = self.Leaky_relu(L2_offset)
L3_offset = self.upsample(L3_offset)
L2_offset = paddle.concat([L2_offset, L3_offset * 2], axis=1)
L2_offset = self.PCD_Align_L2_offset_conv2(L2_offset)
L2_offset = self.Leaky_relu(L2_offset)
L2_offset = self.PCD_Align_L2_offset_conv3(L2_offset)
L2_offset = self.Leaky_relu(L2_offset)
L2_fea = self.PCD_Align_L2_dcn([nbr_fea_l[1], L2_offset])
L3_fea = self.upsample(L3_fea)
L2_fea = paddle.concat([L2_fea, L3_fea], axis=1)
L2_fea = self.PCD_Align_L2_fea_conv(L2_fea)
L2_fea = self.Leaky_relu(L2_fea)
#L1
L1_offset = paddle.concat([nbr_fea_l[0], ref_fea_l[0]], axis=1)
L1_offset = self.PCD_Align_L1_offset_conv1(L1_offset)
L1_offset = self.Leaky_relu(L1_offset)
L2_offset = self.upsample(L2_offset)
L1_offset = paddle.concat([L1_offset, L2_offset * 2], axis=1)
L1_offset = self.PCD_Align_L1_offset_conv2(L1_offset)
L1_offset = self.Leaky_relu(L1_offset)
L1_offset = self.PCD_Align_L1_offset_conv3(L1_offset)
L1_offset = self.Leaky_relu(L1_offset)
L1_fea = self.PCD_Align_L1_dcn([nbr_fea_l[0], L1_offset])
L2_fea = self.upsample(L2_fea)
L1_fea = paddle.concat([L1_fea, L2_fea], axis=1)
L1_fea = self.PCD_Align_L1_fea_conv(L1_fea)
#cascade
offset = paddle.concat([L1_fea, ref_fea_l[0]], axis=1)
offset = self.PCD_Align_cas_offset_conv1(offset)
offset = self.Leaky_relu(offset)
offset = self.PCD_Align_cas_offset_conv2(offset)
offset = self.Leaky_relu(offset)
L1_fea = self.PCD_Align_cascade_dcn([L1_fea, offset])
L1_fea = self.Leaky_relu(L1_fea)
return L1_fea
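A usage sketch for PCD alignment (same deformable-conv caveat as above): each frame contributes a three-level pyramid at full, 1/2 and 1/4 resolution, and the output is the neighbor's level-1 feature aligned to the reference.

import paddle

pcd = PCDAlign(nf=64, groups=8)
nbr_fea_l = [paddle.randn([1, 64, 64, 64]),
             paddle.randn([1, 64, 32, 32]),
             paddle.randn([1, 64, 16, 16])]
ref_fea_l = [paddle.randn([1, 64, 64, 64]),
             paddle.randn([1, 64, 32, 32]),
             paddle.randn([1, 64, 16, 16])]
print(pcd(nbr_fea_l, ref_fea_l).shape)  # [1, 64, 64, 64]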
@GENERATORS.register()
class EDVRNet(nn.Layer):
"""EDVR network structure for video super-resolution.
Currently only the X4 upsampling factor is supported.
Paper:
EDVR: Video Restoration with Enhanced Deformable Convolutional Networks
Args:
in_nf (int): Channel number of input image. Default: 3.
out_nf (int): Channel number of output image. Default: 3.
scale_factor (int): Scale factor from input image to output image. Default: 4.
nf (int): Channel number of intermediate features. Default: 64.
nframes (int): Number of input frames. Default: 5.
groups (int): Deformable groups. Default: 8.
front_RBs (int): Number of blocks for feature extraction. Default: 5.
back_RBs (int): Number of blocks for reconstruction. Default: 10.
center (int): Index of the center frame, counting from 0. Default: None (uses nframes // 2).
predeblur (bool): Whether to use the predeblur module. Default: False.
HR_in (bool): Whether the input is high resolution. Default: False.
w_TSA (bool): Whether to use the TSA module. Default: True.
TSA_only (bool): Whether to train only the TSA module. Default: False.
"""
def __init__(self,
in_nf=3,
out_nf=3,
scale_factor=4,
nf=64,
nframes=5,
groups=8,
front_RBs=5,
back_RBs=10,
center=None,
predeblur=False,
HR_in=False,
w_TSA=True,
TSA_only=False):
super(EDVRNet, self).__init__()
self.in_nf = in_nf
self.out_nf = out_nf
self.scale_factor = scale_factor
self.nf = nf
self.nframes = nframes
self.groups = groups
self.front_RBs = front_RBs
self.back_RBs = back_RBs
self.center = nframes // 2 if center is None else center
self.predeblur = True if predeblur else False
self.HR_in = True if HR_in else False
self.w_TSA = True if w_TSA else False
self.Leaky_relu = nn.LeakyReLU(negative_slope=0.1)
if self.predeblur:
self.pre_deblur = PredeblurResNetPyramid(in_nf=self.in_nf,
nf=self.nf,
HR_in=self.HR_in)
self.cov_1 = nn.Conv2D(in_channels=self.nf,
out_channels=self.nf,
kernel_size=1,
stride=1)
else:
if self.HR_in:
self.conv_first_1 = nn.Conv2D(in_channels=self.in_nf,
out_channels=self.nf,
kernel_size=3,
stride=1,
padding=1)
self.conv_first_2 = nn.Conv2D(in_channels=self.nf,
out_channels=self.nf,
kernel_size=3,
stride=2,
padding=1)
self.conv_first_3 = nn.Conv2D(in_channels=self.nf,
out_channels=self.nf,
kernel_size=3,
stride=2,
padding=1)
else:
self.conv_first = nn.Conv2D(in_channels=self.in_nf,
out_channels=self.nf,
kernel_size=3,
stride=1,
padding=1)
#feature extraction module
self.feature_extractor = MakeMultiBlocks(ResidualBlockNoBN,
self.front_RBs, self.nf)
self.fea_L2_conv1 = nn.Conv2D(in_channels=self.nf,
out_channels=self.nf,
kernel_size=3,
stride=2,
padding=1)
self.fea_L2_conv2 = nn.Conv2D(in_channels=self.nf,
out_channels=self.nf,
kernel_size=3,
stride=1,
padding=1)
self.fea_L3_conv1 = nn.Conv2D(
in_channels=self.nf,
out_channels=self.nf,
kernel_size=3,
stride=2,
padding=1,
)
self.fea_L3_conv2 = nn.Conv2D(in_channels=self.nf,
out_channels=self.nf,
kernel_size=3,
stride=1,
padding=1)
#PCD alignment module
self.PCDModule = PCDAlign(nf=self.nf, groups=self.groups)
#TSA Fusion module
if self.w_TSA:
self.TSAModule = TSAFusion(nf=self.nf,
nframes=self.nframes,
center=self.center)
else:
self.TSAModule = nn.Conv2D(in_channels=self.nframes * self.nf,
out_channels=self.nf,
kernel_size=1,
stride=1)
#reconstruction module
self.reconstructor = MakeMultiBlocks(ResidualBlockNoBN, self.back_RBs,
self.nf)
self.upconv1 = nn.Conv2D(in_channels=self.nf,
out_channels=4 * self.nf,
kernel_size=3,
stride=1,
padding=1)
self.pixel_shuffle = nn.PixelShuffle(2)
self.upconv2 = nn.Conv2D(in_channels=self.nf,
out_channels=4 * self.nf,
kernel_size=3,
stride=1,
padding=1)
self.HRconv = nn.Conv2D(in_channels=self.nf,
out_channels=self.nf,
kernel_size=3,
stride=1,
padding=1)
self.conv_last = nn.Conv2D(in_channels=self.nf,
out_channels=self.out_nf,
kernel_size=3,
stride=1,
padding=1)
self.upsample = nn.Upsample(scale_factor=self.scale_factor,
mode="bilinear",
align_corners=False,
align_mode=0)
def forward(self, x):
"""
Args:
x (Tensor): Input features with shape (b, n, c, h, w).
Returns:
Tensor: Features after EDVR with the shape (b, c, scale_factor*h, scale_factor*w).
"""
B, N, C, H, W = x.shape
x_center = x[:, self.center, :, :, :]
L1_fea = x.reshape([-1, C, H, W])  #[B*N,C,H,W]
if self.predeblur:
L1_fea = self.pre_deblur(L1_fea)
L1_fea = self.cov_1(L1_fea)
if self.HR_in:
H, W = H // self.scale_factor, W // self.scale_factor
else:
if self.HR_in:
L1_fea = self.conv_first_1(L1_fea)
L1_fea = self.Leaky_relu(L1_fea)
L1_fea = self.conv_first_2(L1_fea)
L1_fea = self.Leaky_relu(L1_fea)
L1_fea = self.conv_first_3(L1_fea)
L1_fea = self.Leaky_relu(L1_fea)
H = H // self.scale_factor
W = W // self.scale_factor
else:
L1_fea = self.conv_first(L1_fea)
L1_fea = self.Leaky_relu(L1_fea)
# feature extraction and create Pyramid
L1_fea = self.feature_extractor(L1_fea)
# L2
L2_fea = self.fea_L2_conv1(L1_fea)
L2_fea = self.Leaky_relu(L2_fea)
L2_fea = self.fea_L2_conv2(L2_fea)
L2_fea = self.Leaky_relu(L2_fea)
# L3
L3_fea = self.fea_L3_conv1(L2_fea)
L3_fea = self.Leaky_relu(L3_fea)
L3_fea = self.fea_L3_conv2(L3_fea)
L3_fea = self.Leaky_relu(L3_fea)
L1_fea = L1_fea.reshape([-1, N, self.nf, H, W])
L2_fea = L2_fea.reshape([-1, N, self.nf, H // 2, W // 2])
L3_fea = L3_fea.reshape([-1, N, self.nf, H // 4, W // 4])
# pcd align
ref_fea_l = [
L1_fea[:, self.center, :, :, :], L2_fea[:, self.center, :, :, :],
L3_fea[:, self.center, :, :, :]
]
aligned_fea = []
for i in range(N):
nbr_fea_l = [
L1_fea[:, i, :, :, :], L2_fea[:, i, :, :, :], L3_fea[:,
i, :, :, :]
]
aligned_fea.append(self.PCDModule(nbr_fea_l, ref_fea_l))
# TSA Fusion
aligned_fea = paddle.stack(aligned_fea, axis=1) # [B, N, C, H, W]
fea = None
if not self.w_TSA:
aligned_fea = aligned_fea.reshape([B, -1, H, W])
fea = self.TSAModule(aligned_fea)  # [B, C, H, W]
#Reconstruct
out = self.reconstructor(fea)
out = self.upconv1(out)
out = self.pixel_shuffle(out)
out = self.Leaky_relu(out)
out = self.upconv2(out)
out = self.pixel_shuffle(out)
out = self.Leaky_relu(out)
out = self.HRconv(out)
out = self.Leaky_relu(out)
out = self.conv_last(out)
if self.HR_in:
base = x_center
else:
base = self.upsample(x_center)
out += base
return out
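An end-to-end shape sketch for the generator above (same deformable-conv caveat): five low-resolution frames in, one x4 super-resolved center frame out, matching the X4-only note in the docstring.

import paddle

net = EDVRNet(in_nf=3, out_nf=3, scale_factor=4, nf=64, nframes=5,
              groups=8, front_RBs=5, back_RBs=10, center=2)
frames = paddle.randn([1, 5, 3, 64, 64])  # (b, n, c, h, w)
print(net(frames).shape)  # [1, 3, 256, 256]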
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.data_feeder import check_variable_and_dtype, check_type, check_dtype
from paddle.fluid.layers import deformable_conv
from paddle.fluid import core, layers
from paddle.fluid.layers import nn, utils
from paddle.nn import Layer
from paddle.fluid.initializer import Normal
from paddle.common_ops_import import *
class DeformConv2D(Layer):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
deformable_groups=1,
groups=1,
weight_attr=None,
bias_attr=None):
super(DeformConv2D, self).__init__()
assert weight_attr is not False, "weight_attr should not be False in Conv."
self._weight_attr = weight_attr
self._bias_attr = bias_attr
self._deformable_groups = deformable_groups
self._groups = groups
self._in_channels = in_channels
self._out_channels = out_channels
self.padding = padding
self.stride = stride
self._channel_dim = 1
self._stride = utils.convert_to_list(stride, 2, 'stride')
self._dilation = utils.convert_to_list(dilation, 2, 'dilation')
self._kernel_size = utils.convert_to_list(kernel_size, 2, 'kernel_size')
if in_channels % groups != 0:
raise ValueError("in_channels must be divisible by groups.")
self._padding = utils.convert_to_list(padding, 2, 'padding')
filter_shape = [out_channels, in_channels // groups] + self._kernel_size
def _get_default_param_initializer():
filter_elem_num = np.prod(self._kernel_size) * self._in_channels
std = (2.0 / filter_elem_num)**0.5
return Normal(0.0, std, 0)
self.weight = self.create_parameter(
shape=filter_shape,
attr=self._weight_attr,
default_initializer=_get_default_param_initializer())
self.bias = self.create_parameter(
attr=self._bias_attr, shape=[self._out_channels], is_bias=True)
def forward(self, x, offset, mask):
out = deform_conv2d(
x=x,
offset=offset,
mask=mask,
weight=self.weight,
bias=self.bias,
stride=self._stride,
padding=self._padding,
dilation=self._dilation,
deformable_groups=self._deformable_groups,
groups=self._groups,
)
return out
def deform_conv2d(x,
offset,
weight,
mask,
bias=None,
stride=1,
padding=0,
dilation=1,
deformable_groups=1,
groups=1,
name=None):
stride = utils.convert_to_list(stride, 2, 'stride')
padding = utils.convert_to_list(padding, 2, 'padding')
dilation = utils.convert_to_list(dilation, 2, 'dilation')
use_deform_conv2d_v1 = True if mask is None else False
if in_dygraph_mode():
attrs = ('strides', stride, 'paddings', padding, 'dilations', dilation,
         'deformable_groups', deformable_groups, 'groups', groups,
         'im2col_step', 1)
if use_deform_conv2d_v1:
op_type = 'deformable_conv_v1'
pre_bias = getattr(core.ops, op_type)(x, offset, weight, *attrs)
else:
op_type = 'deformable_conv'
pre_bias = getattr(core.ops, op_type)(x, offset, mask, weight,
*attrs)
if bias is not None:
out = nn.elementwise_add(pre_bias, bias, axis=1)
else:
out = pre_bias
return out
class DeformableConv_dygraph(Layer):
def __init__(self, num_filters, filter_size, dilation,
             stride, padding, deformable_groups=1, groups=1):
super(DeformableConv_dygraph, self).__init__()
self.num_filters = num_filters
self.filter_size = filter_size
self.dilation = dilation
self.stride = stride
self.padding = padding
self.deformable_groups = deformable_groups
self.groups = groups
self.defor_conv = DeformConv2D(in_channels=self.num_filters, out_channels=self.num_filters,
kernel_size=self.filter_size, stride=self.stride, padding=self.padding,
dilation=self.dilation, deformable_groups=self.deformable_groups, groups=self.groups, weight_attr=None, bias_attr=None)
def forward(self,*input):
x = input[0]
offset = input[1]
mask = input[2]
out = self.defor_conv(x, offset, mask)
return out
@@ -324,3 +324,10 @@ def init_weights(net,
logger = get_logger()
logger.debug('initialize network with %s' % init_type)
net.apply(init_func) # apply the initialization function <init_func>
def reset_parameters(m):
kaiming_uniform_(m.weight, a=math.sqrt(5))
if m.bias is not None:
fan_in, _ = _calculate_fan_in_and_fan_out(m.weight)
bound = 1 / math.sqrt(fan_in)
uniform_(m.bias, -bound, bound)
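For reference, the helper above is the Kaiming-uniform fan-in scheme (with a = sqrt(5), the same default PyTorch applies to conv and linear layers); for that slope the weight bound sqrt(3) * gain / sqrt(fan_in) also reduces to 1 / sqrt(fan_in). A hypothetical standalone version for a Conv2D layer (kaiming_reset is illustrative, not part of this commit):

import math
import paddle

def kaiming_reset(conv):
    # fan_in = (in_channels // groups) * kernel_h * kernel_w
    fan_in = conv.weight.shape[1] * conv.weight.shape[2] * conv.weight.shape[3]
    bound = 1.0 / math.sqrt(fan_in)
    conv.weight.set_value(paddle.uniform(conv.weight.shape, min=-bound, max=bound))
    if conv.bias is not None:
        conv.bias.set_value(paddle.uniform(conv.bias.shape, min=-bound, max=bound))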