未验证 提交 7131e67c 编写于 作者: 农夫三拳_'s avatar 农夫三拳_ 提交者: GitHub

add mpr deblur derain and denoise pretrained models for image restoration application (#352)

* add mpr deblur derain and denoise pretrained models for image restoration application
上级 97f96b94
import paddle
import os
import sys
sys.path.insert(0, os.getcwd())
from ppgan.apps import MPRPredictor
import argparse
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--output_path",
type=str,
default='output_dir',
help="path to output image dir")
parser.add_argument("--weight_path",
type=str,
default=None,
help="path to model checkpoint path")
parser.add_argument("--seed",
type=int,
default=None,
help="sample random seed for model's image generation")
parser.add_argument('--images_path',
default=None,
required=True,
type=str,
help='Single image or images directory.')
parser.add_argument('--task',
required=True,
type=str,
help='Task to run',
choices=['Deblurring', 'Denoising', 'Deraining'])
parser.add_argument("--cpu",
dest="cpu",
action="store_true",
help="cpu mode.")
args = parser.parse_args()
if args.cpu:
paddle.set_device('cpu')
predictor = MPRPredictor(
images_path=args.images_path,
output_path=args.output_path,
weight_path=args.weight_path,
seed=args.seed,
task=args.task
)
predictor.run()
total_iters: 100000
output_dir: output_dir
model:
name: MPRModel
generator:
name: MPRNet
char_criterion:
name: CharbonnierLoss
edge_criterion:
name: EdgeLoss
dataset:
train:
name: MPRTrain
rgb_dir: 'data/GoPro/train'
num_workers: 16
batch_size: 4
img_options:
patch_size: 256
test:
name: MPRTrain
rgb_dir: 'data/GoPro/test'
num_workers: 16
batch_size: 4
img_options:
patch_size: 256
lr_scheduler:
name: CosineAnnealingRestartLR
learning_rate: !!float 2e-4
periods: [25000, 25000, 25000, 25000]
restart_weights: [1, 1, 1, 1]
eta_min: !!float 1e-6
validate:
interval: 10
save_img: false
metrics:
psnr: # metric name, can be arbitrary
name: PSNR
crop_border: 4
test_y_channel: True
ssim:
name: SSIM
crop_border: 4
test_y_channel: True
optimizer:
name: Adam
# add parameters of net_name to optim
# name should in self.nets
net_names:
- generator
beta1: 0.9
beta2: 0.999
epsilon: 1e-8
log_config:
interval: 10
visiual_interval: 5000
snapshot_config:
interval: 5000
...@@ -25,3 +25,4 @@ from .photo2cartoon_predictor import Photo2CartoonPredictor ...@@ -25,3 +25,4 @@ from .photo2cartoon_predictor import Photo2CartoonPredictor
from .styleganv2_predictor import StyleGANv2Predictor from .styleganv2_predictor import StyleGANv2Predictor
from .pixel2style2pixel_predictor import Pixel2Style2PixelPredictor from .pixel2style2pixel_predictor import Pixel2Style2PixelPredictor
from .wav2lip_predictor import Wav2LipPredictor from .wav2lip_predictor import Wav2LipPredictor
from .mpr_predictor import MPRPredictor
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import random
from natsort import natsorted
from glob import glob
import numpy as np
import cv2
from PIL import Image
import paddle
from .base_predictor import BasePredictor
from ppgan.models.generators import MPRNet
from ppgan.utils.download import get_path_from_url
from ppgan.utils.visual import make_grid, tensor2img, save_image
from ppgan.datasets.mpr_dataset import to_tensor
from paddle.vision.transforms import Pad
from tqdm import tqdm
model_cfgs = {
'Deblurring': {
'model_urls':
'https://paddlegan.bj.bcebos.com/models/MPR_Deblurring.pdparams',
'n_feat': 96,
'scale_unetfeats': 48,
'scale_orsnetfeats': 32,
},
'Denoising': {
'model_urls':
'https://paddlegan.bj.bcebos.com/models/MPR_Denoising.pdparams',
'n_feat': 80,
'scale_unetfeats': 48,
'scale_orsnetfeats': 32,
},
'Deraining': {
'model_urls':
'https://paddlegan.bj.bcebos.com/models/MPR_Deraining.pdparams',
'n_feat': 40,
'scale_unetfeats': 20,
'scale_orsnetfeats': 16,
}
}
class MPRPredictor(BasePredictor):
def __init__(self,
images_path=None,
output_path='output_dir',
weight_path=None,
seed=None,
task=None):
self.output_path = output_path
self.images_path = images_path
self.task = task
self.max_size = 640
self.img_multiple_of = 8
if weight_path is None:
if task in model_cfgs.keys():
weight_path = get_path_from_url(model_cfgs[task]['model_urls'])
checkpoint = paddle.load(weight_path)
else:
raise ValueError(
'Predictor need a weight path or a pretrained model type')
else:
checkpoint = paddle.load(weight_path)
self.generator = MPRNet(
n_feat=model_cfgs[task]['n_feat'],
scale_unetfeats=model_cfgs[task]['scale_unetfeats'],
scale_orsnetfeats=model_cfgs[task]['scale_orsnetfeats'])
self.generator.set_state_dict(checkpoint)
self.generator.eval()
if seed is not None:
paddle.seed(seed)
random.seed(seed)
np.random.seed(seed)
def get_images(self, images_path):
if os.path.isdir(images_path):
return natsorted(
glob(os.path.join(images_path, '*.jpg')) +
glob(os.path.join(images_path, '*.JPG')) +
glob(os.path.join(images_path, '*.png')) +
glob(os.path.join(images_path, '*.PNG')))
else:
return [images_path]
def read_image(self, image_file):
img = Image.open(image_file).convert('RGB')
max_length = max(img.width, img.height)
if max_length > self.max_size:
ratio = max_length / self.max_size
dw = int(img.width / ratio)
dh = int(img.height / ratio)
img = img.resize((dw, dh))
return img
def run(self):
os.makedirs(self.output_path, exist_ok=True)
task_path = os.path.join(self.output_path, self.task)
os.makedirs(task_path, exist_ok=True)
image_files = self.get_images(self.images_path)
for image_file in tqdm(image_files):
img = self.read_image(image_file)
image_name = os.path.basename(image_file)
img.save(os.path.join(task_path, image_name))
tmps = image_name.split('.')
assert len(
tmps) == 2, f'Invalid image name: {image_name}, too much "."'
restoration_save_path = os.path.join(
task_path, f'{tmps[0]}_restoration.{tmps[1]}')
input_ = to_tensor(img)
# Pad the input if not_multiple_of 8
h, w = input_.shape[1], input_.shape[2]
H, W = ((h + self.img_multiple_of) //
self.img_multiple_of) * self.img_multiple_of, (
(w + self.img_multiple_of) //
self.img_multiple_of) * self.img_multiple_of
padh = H - h if h % self.img_multiple_of != 0 else 0
padw = W - w if w % self.img_multiple_of != 0 else 0
input_ = paddle.to_tensor(input_)
transform = Pad((0, 0, padw, padh), padding_mode='reflect')
input_ = transform(input_)
input_ = paddle.to_tensor(np.expand_dims(input_.numpy(), 0))
with paddle.no_grad():
restored = self.generator(input_)
restored = restored[0]
restored = paddle.clip(restored, 0, 1)
# Unpad the output
restored = restored[:, :, :h, :w]
restored = restored.numpy()
restored = restored.transpose(0, 2, 3, 1)
restored = restored[0]
restored = restored * 255
restored = restored.astype(np.uint8)
cv2.imwrite(restoration_save_path,
cv2.cvtColor(restored, cv2.COLOR_RGB2BGR))
print('Done, output path is:', task_path)
...@@ -24,3 +24,4 @@ from .starganv2_dataset import StarGANv2Dataset ...@@ -24,3 +24,4 @@ from .starganv2_dataset import StarGANv2Dataset
from .edvr_dataset import REDSDataset from .edvr_dataset import REDSDataset
from .firstorder_dataset import FirstOrderDataset from .firstorder_dataset import FirstOrderDataset
from .lapstyle_dataset import LapStyleDataset from .lapstyle_dataset import LapStyleDataset
from .mpr_dataset import MPRTrain, MPRVal, MPRTest
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import random
import numpy as np
import cv2
import paddle
from PIL import Image, ImageEnhance
import numpy as np
import random
import numbers
from paddle.io import Dataset
from .builder import DATASETS
from paddle.vision.transforms.functional import to_tensor, adjust_brightness, adjust_saturation, rotate, hflip, hflip, vflip, center_crop
def is_image_file(filename):
return any(
filename.endswith(extension)
for extension in ['jpeg', 'JPEG', 'jpg', 'png', 'JPG', 'PNG', 'gif'])
@DATASETS.register()
class MPRTrain(Dataset):
def __init__(self, rgb_dir, img_options=None):
super(MPRTrain, self).__init__()
inp_files = sorted(os.listdir(os.path.join(rgb_dir, 'input')))
tar_files = sorted(os.listdir(os.path.join(rgb_dir, 'target')))
self.inp_filenames = [
os.path.join(rgb_dir, 'input', x) for x in inp_files
if is_image_file(x)
]
self.tar_filenames = [
os.path.join(rgb_dir, 'target', x) for x in tar_files
if is_image_file(x)
]
self.img_options = img_options
self.sizex = len(self.tar_filenames) # get the size of target
self.ps = self.img_options['patch_size']
def __len__(self):
return self.sizex
def __getitem__(self, index):
index_ = index % self.sizex
ps = self.ps
inp_path = self.inp_filenames[index_]
tar_path = self.tar_filenames[index_]
inp_img = Image.open(inp_path)
tar_img = Image.open(tar_path)
w, h = tar_img.size
padw = ps - w if w < ps else 0
padh = ps - h if h < ps else 0
# Reflect Pad in case image is smaller than patch_size
if padw != 0 or padh != 0:
inp_img = np.pad(inp_img, (0, 0, padw, padh),
padding_mode='reflect')
tar_img = np.pad(tar_img, (0, 0, padw, padh),
padding_mode='reflect')
aug = random.randint(0, 2)
if aug == 1:
inp_img = adjust_brightness(inp_img, 1)
tar_img = adjust_brightness(tar_img, 1)
aug = random.randint(0, 2)
if aug == 1:
sat_factor = 1 + (0.2 - 0.4 * np.random.rand())
inp_img = adjust_saturation(inp_img, sat_factor)
tar_img = adjust_saturation(tar_img, sat_factor)
# Data Augmentations
if aug == 1:
inp_img = vflip(inp_img)
tar_img = vflip(tar_img)
elif aug == 2:
inp_img = hflip(inp_img)
tar_img = hflip(tar_img)
elif aug == 3:
inp_img = rotate(inp_img, 90)
tar_img = rotate(tar_img, 90)
elif aug == 4:
inp_img = rotate(inp_img, 90 * 2)
tar_img = rotate(tar_img, 90 * 2)
elif aug == 5:
inp_img = rotate(inp_img, 90 * 3)
tar_img = rotate(tar_img, 90 * 3)
elif aug == 6:
inp_img = rotate(vflip(inp_img), 90)
tar_img = rotate(vflip(tar_img), 90)
elif aug == 7:
inp_img = rotate(hflip(inp_img), 90)
tar_img = rotate(hflip(tar_img), 90)
inp_img = to_tensor(inp_img)
tar_img = to_tensor(tar_img)
hh, ww = tar_img.shape[1], tar_img.shape[2]
rr = random.randint(0, hh - ps)
cc = random.randint(0, ww - ps)
aug = random.randint(0, 8)
# Crop patch
inp_img = inp_img[:, rr:rr + ps, cc:cc + ps]
tar_img = tar_img[:, rr:rr + ps, cc:cc + ps]
filename = os.path.splitext(os.path.split(tar_path)[-1])[0]
return tar_img, inp_img, filename
@DATASETS.register()
class MPRVal(Dataset):
def __init__(self, rgb_dir, img_options=None, rgb_dir2=None):
super(MPRVal, self).__init__()
inp_files = sorted(os.listdir(os.path.join(rgb_dir, 'input')))
tar_files = sorted(os.listdir(os.path.join(rgb_dir, 'target')))
self.inp_filenames = [
os.path.join(rgb_dir, 'input', x) for x in inp_files
if is_image_file(x)
]
self.tar_filenames = [
os.path.join(rgb_dir, 'target', x) for x in tar_files
if is_image_file(x)
]
self.img_options = img_options
self.sizex = len(self.tar_filenames) # get the size of target
self.ps = self.img_options['patch_size']
def __len__(self):
return self.sizex
def __getitem__(self, index):
index_ = index % self.sizex
ps = self.ps
inp_path = self.inp_filenames[index_]
tar_path = self.tar_filenames[index_]
inp_img = Image.open(inp_path)
tar_img = Image.open(tar_path)
# Validate on center crop
if self.ps is not None:
inp_img = center_crop(inp_img, (ps, ps))
tar_img = center_crop(tar_img, (ps, ps))
inp_img = to_tensor(inp_img)
tar_img = to_tensor(tar_img)
filename = os.path.splitext(os.path.split(tar_path)[-1])[0]
return tar_img, inp_img, filename
@DATASETS.register()
class MPRTest(Dataset):
def __init__(self, inp_dir, img_options):
super(MPRTest, self).__init__()
inp_files = sorted(os.listdir(inp_dir))
self.inp_filenames = [
os.path.join(inp_dir, x) for x in inp_files if is_image_file(x)
]
self.inp_size = len(self.inp_filenames)
self.img_options = img_options
def __len__(self):
return self.inp_size
def __getitem__(self, index):
path_inp = self.inp_filenames[index]
filename = os.path.splitext(os.path.split(path_inp)[-1])[0]
inp = Image.open(path_inp)
inp = to_tensor(inp)
return inp, filename
...@@ -30,3 +30,4 @@ from .starganv2_model import StarGANv2Model ...@@ -30,3 +30,4 @@ from .starganv2_model import StarGANv2Model
from .edvr_model import EDVRModel from .edvr_model import EDVRModel
from .firstorder_model import FirstOrderModel from .firstorder_model import FirstOrderModel
from .lapstyle_model import LapStyleDraModel, LapStyleRevFirstModel, LapStyleRevSecondModel from .lapstyle_model import LapStyleDraModel, LapStyleRevFirstModel, LapStyleRevSecondModel
from .mpr_model import MPRModel
...@@ -2,6 +2,6 @@ from .gan_loss import GANLoss ...@@ -2,6 +2,6 @@ from .gan_loss import GANLoss
from .perceptual_loss import PerceptualLoss from .perceptual_loss import PerceptualLoss
from .pixel_loss import L1Loss, MSELoss, CharbonnierLoss, \ from .pixel_loss import L1Loss, MSELoss, CharbonnierLoss, \
CalcStyleEmdLoss, CalcContentReltLoss, \ CalcStyleEmdLoss, CalcContentReltLoss, \
CalcContentLoss, CalcStyleLoss CalcContentLoss, CalcStyleLoss, EdgeLoss
from .builder import build_criterion from .builder import build_criterion
...@@ -17,6 +17,7 @@ from ..generators.generater_lapstyle import calc_mean_std, mean_variance_norm ...@@ -17,6 +17,7 @@ from ..generators.generater_lapstyle import calc_mean_std, mean_variance_norm
import paddle import paddle
import paddle.nn as nn import paddle.nn as nn
import paddle.nn.functional as F
from .builder import CRITERIONS from .builder import CRITERIONS
...@@ -234,3 +235,29 @@ class CalcStyleLoss(): ...@@ -234,3 +235,29 @@ class CalcStyleLoss():
target_mean, target_std = calc_mean_std(target) target_mean, target_std = calc_mean_std(target)
return self.mse_loss(pred_mean, target_mean) + self.mse_loss( return self.mse_loss(pred_mean, target_mean) + self.mse_loss(
pred_std, target_std) pred_std, target_std)
@CRITERIONS.register()
class EdgeLoss():
def __init__(self):
k = paddle.to_tensor([[.05, .25, .4, .25, .05]])
self.kernel = paddle.matmul(k.t(),k).unsqueeze(0).tile([3,1,1,1])
self.loss = CharbonnierLoss()
def conv_gauss(self, img):
n_channels, _, kw, kh = self.kernel.shape
img = F.pad(img, [kw//2, kh//2, kw//2, kh//2], mode='replicate')
return F.conv2d(img, self.kernel, groups=n_channels)
def laplacian_kernel(self, current):
filtered = self.conv_gauss(current) # filter
down = filtered[:,:,::2,::2] # downsample
new_filter = paddle.zeros_like(filtered)
new_filter[:,:,::2,::2] = down*4 # upsample
filtered = self.conv_gauss(new_filter) # filter
diff = current - filtered
return diff
def __call__(self, x, y):
loss = self.loss(self.laplacian_kernel(x), self.laplacian_kernel(y))
return loss
\ No newline at end of file
...@@ -30,3 +30,4 @@ from .generator_starganv2 import StarGANv2Generator, StarGANv2Style, StarGANv2Ma ...@@ -30,3 +30,4 @@ from .generator_starganv2 import StarGANv2Generator, StarGANv2Style, StarGANv2Ma
from .edvr import EDVRNet from .edvr import EDVRNet
from .generator_firstorder import FirstOrderGenerator from .generator_firstorder import FirstOrderGenerator
from .generater_lapstyle import DecoderNet, Encoder, RevisionNet from .generater_lapstyle import DecoderNet, Encoder, RevisionNet
from .mpr import MPRNet
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import numpy as np
import paddle.nn as nn
from ...modules.init import kaiming_normal_, constant_
from ...modules.dcn import DeformableConv_dygraph
# from paddle.vision.ops import DeformConv2D #to be compiled
from .builder import GENERATORS
import paddle
from paddle import nn
import paddle.nn.functional as F
##########################################################################
def conv(in_channels, out_channels, kernel_size, bias_attr=False, stride=1):
return nn.Conv2D(
in_channels, out_channels, kernel_size,
padding=(kernel_size//2), bias_attr=bias_attr, stride = stride)
##########################################################################
## Channel Attention Layer
class CALayer(nn.Layer):
def __init__(self, channel, reduction=16, bias_attr=False):
super(CALayer, self).__init__()
# global average pooling: feature --> point
self.avg_pool = nn.AdaptiveAvgPool2D(1)
# feature channel downscale and upscale --> channel weight
self.conv_du = nn.Sequential(
nn.Conv2D(channel, channel // reduction, 1, padding=0, bias_attr=bias_attr),
nn.ReLU(),
# nn.ReLU(inplace=True), torch
nn.Conv2D(channel // reduction, channel, 1, padding=0, bias_attr=bias_attr),
nn.Sigmoid()
)
def forward(self, x):
y = self.avg_pool(x)
y = self.conv_du(y)
return x * y
##########################################################################
## Channel Attention Block (CAB)
class CAB(nn.Layer):
def __init__(self, n_feat, kernel_size, reduction, bias_attr, act):
super(CAB, self).__init__()
modules_body = []
modules_body.append(conv(n_feat, n_feat, kernel_size, bias_attr=bias_attr))
modules_body.append(act)
modules_body.append(conv(n_feat, n_feat, kernel_size, bias_attr=bias_attr))
self.CA = CALayer(n_feat, reduction, bias_attr=bias_attr)
self.body = nn.Sequential(*modules_body)
def forward(self, x):
res = self.body(x)
res = self.CA(res)
res += x
return res
##########################################################################
##---------- Resizing Modules ----------
class DownSample(nn.Layer):
def __init__(self, in_channels,s_factor):
super(DownSample, self).__init__()
self.down = nn.Sequential(nn.Upsample(scale_factor=0.5, mode='bilinear', align_corners=False),
nn.Conv2D(in_channels, in_channels+s_factor, 1, stride=1, padding=0, bias_attr=False))
def forward(self, x):
x = self.down(x)
return x
class UpSample(nn.Layer):
def __init__(self, in_channels,s_factor):
super(UpSample, self).__init__()
self.up = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
nn.Conv2D(in_channels+s_factor, in_channels, 1, stride=1, padding=0, bias_attr=False))
def forward(self, x):
x = self.up(x)
return x
class SkipUpSample(nn.Layer):
def __init__(self, in_channels,s_factor):
super(SkipUpSample, self).__init__()
self.up = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
nn.Conv2D(in_channels+s_factor, in_channels, 1, stride=1, padding=0, bias_attr=False))
def forward(self, x, y):
x = self.up(x)
x = x + y
return x
##########################################################################
## U-Net
class Encoder(nn.Layer):
def __init__(self, n_feat, kernel_size, reduction, act, bias_attr, scale_unetfeats, csff):
super(Encoder, self).__init__()
self.encoder_level1 = [CAB(n_feat, kernel_size, reduction, bias_attr=bias_attr, act=act) for _ in range(2)]
self.encoder_level2 = [CAB(n_feat+scale_unetfeats, kernel_size, reduction, bias_attr=bias_attr, act=act) for _ in range(2)]
self.encoder_level3 = [CAB(n_feat+(scale_unetfeats*2), kernel_size, reduction, bias_attr=bias_attr, act=act) for _ in range(2)]
self.encoder_level1 = nn.Sequential(*self.encoder_level1)
self.encoder_level2 = nn.Sequential(*self.encoder_level2)
self.encoder_level3 = nn.Sequential(*self.encoder_level3)
self.down12 = DownSample(n_feat, scale_unetfeats)
self.down23 = DownSample(n_feat+scale_unetfeats, scale_unetfeats)
# Cross Stage Feature Fusion (CSFF)
if csff:
self.csff_enc1 = nn.Conv2D(n_feat, n_feat, kernel_size=1, bias_attr=bias_attr)
self.csff_enc2 = nn.Conv2D(n_feat+scale_unetfeats, n_feat+scale_unetfeats, kernel_size=1, bias_attr=bias_attr)
self.csff_enc3 = nn.Conv2D(n_feat+(scale_unetfeats*2), n_feat+(scale_unetfeats*2), kernel_size=1, bias_attr=bias_attr)
self.csff_dec1 = nn.Conv2D(n_feat, n_feat, kernel_size=1, bias_attr=bias_attr)
self.csff_dec2 = nn.Conv2D(n_feat+scale_unetfeats, n_feat+scale_unetfeats, kernel_size=1, bias_attr=bias_attr)
self.csff_dec3 = nn.Conv2D(n_feat+(scale_unetfeats*2), n_feat+(scale_unetfeats*2), kernel_size=1, bias_attr=bias_attr)
def forward(self, x, encoder_outs=None, decoder_outs=None):
enc1 = self.encoder_level1(x)
if (encoder_outs is not None) and (decoder_outs is not None):
enc1 = enc1 + self.csff_enc1(encoder_outs[0]) + self.csff_dec1(decoder_outs[0])
x = self.down12(enc1)
enc2 = self.encoder_level2(x)
if (encoder_outs is not None) and (decoder_outs is not None):
enc2 = enc2 + self.csff_enc2(encoder_outs[1]) + self.csff_dec2(decoder_outs[1])
x = self.down23(enc2)
enc3 = self.encoder_level3(x)
if (encoder_outs is not None) and (decoder_outs is not None):
enc3 = enc3 + self.csff_enc3(encoder_outs[2]) + self.csff_dec3(decoder_outs[2])
return [enc1, enc2, enc3]
class Decoder(nn.Layer):
def __init__(self, n_feat, kernel_size, reduction, act, bias_attr, scale_unetfeats):
super(Decoder, self).__init__()
self.decoder_level1 = [CAB(n_feat, kernel_size, reduction, bias_attr=bias_attr, act=act) for _ in range(2)]
self.decoder_level2 = [CAB(n_feat+scale_unetfeats, kernel_size, reduction, bias_attr=bias_attr, act=act) for _ in range(2)]
self.decoder_level3 = [CAB(n_feat+(scale_unetfeats*2), kernel_size, reduction, bias_attr=bias_attr, act=act) for _ in range(2)]
self.decoder_level1 = nn.Sequential(*self.decoder_level1)
self.decoder_level2 = nn.Sequential(*self.decoder_level2)
self.decoder_level3 = nn.Sequential(*self.decoder_level3)
self.skip_attn1 = CAB(n_feat, kernel_size, reduction, bias_attr=bias_attr, act=act)
self.skip_attn2 = CAB(n_feat+scale_unetfeats, kernel_size, reduction, bias_attr=bias_attr, act=act)
self.up21 = SkipUpSample(n_feat, scale_unetfeats)
self.up32 = SkipUpSample(n_feat+scale_unetfeats, scale_unetfeats)
def forward(self, outs):
enc1, enc2, enc3 = outs
dec3 = self.decoder_level3(enc3)
x = self.up32(dec3, self.skip_attn2(enc2))
dec2 = self.decoder_level2(x)
x = self.up21(dec2, self.skip_attn1(enc1))
dec1 = self.decoder_level1(x)
return [dec1,dec2,dec3]
##########################################################################
## Original Resolution Block (ORB)
class ORB(nn.Layer):
def __init__(self, n_feat, kernel_size, reduction, act, bias_attr, num_cab):
super(ORB, self).__init__()
modules_body = []
modules_body = [CAB(n_feat, kernel_size, reduction, bias_attr=bias_attr, act=act) for _ in range(num_cab)]
modules_body.append(conv(n_feat, n_feat, kernel_size))
self.body = nn.Sequential(*modules_body)
def forward(self, x):
res = self.body(x)
res += x
return res
##########################################################################
class ORSNet(nn.Layer):
def __init__(self, n_feat, scale_orsnetfeats, kernel_size, reduction, act, bias_attr, scale_unetfeats, num_cab):
super(ORSNet, self).__init__()
self.orb1 = ORB(n_feat+scale_orsnetfeats, kernel_size, reduction, act, bias_attr, num_cab)
self.orb2 = ORB(n_feat+scale_orsnetfeats, kernel_size, reduction, act, bias_attr, num_cab)
self.orb3 = ORB(n_feat+scale_orsnetfeats, kernel_size, reduction, act, bias_attr, num_cab)
self.up_enc1 = UpSample(n_feat, scale_unetfeats)
self.up_dec1 = UpSample(n_feat, scale_unetfeats)
self.up_enc2 = nn.Sequential(UpSample(n_feat+scale_unetfeats, scale_unetfeats), UpSample(n_feat, scale_unetfeats))
self.up_dec2 = nn.Sequential(UpSample(n_feat+scale_unetfeats, scale_unetfeats), UpSample(n_feat, scale_unetfeats))
self.conv_enc1 = nn.Conv2D(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias_attr=bias_attr)
self.conv_enc2 = nn.Conv2D(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias_attr=bias_attr)
self.conv_enc3 = nn.Conv2D(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias_attr=bias_attr)
self.conv_dec1 = nn.Conv2D(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias_attr=bias_attr)
self.conv_dec2 = nn.Conv2D(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias_attr=bias_attr)
self.conv_dec3 = nn.Conv2D(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias_attr=bias_attr)
def forward(self, x, encoder_outs, decoder_outs):
x = self.orb1(x)
x = x + self.conv_enc1(encoder_outs[0]) + self.conv_dec1(decoder_outs[0])
x = self.orb2(x)
x = x + self.conv_enc2(self.up_enc1(encoder_outs[1])) + self.conv_dec2(self.up_dec1(decoder_outs[1]))
x = self.orb3(x)
x = x + self.conv_enc3(self.up_enc2(encoder_outs[2])) + self.conv_dec3(self.up_dec2(decoder_outs[2]))
return x
##########################################################################
## Supervised Attention Module
class SAM(nn.Layer):
def __init__(self, n_feat, kernel_size, bias_attr):
super(SAM, self).__init__()
self.conv1 = conv(n_feat, n_feat, kernel_size, bias_attr=bias_attr)
self.conv2 = conv(n_feat, 3, kernel_size, bias_attr=bias_attr)
self.conv3 = conv(3, n_feat, kernel_size, bias_attr=bias_attr)
def forward(self, x, x_img):
x1 = self.conv1(x)
img = self.conv2(x) + x_img
x2 = F.sigmoid(self.conv3(img))
x1 = x1*x2
x1 = x1+x
return x1, img
@GENERATORS.register()
class MPRNet(nn.Layer):
def __init__(self, in_c=3, out_c=3, n_feat=96, scale_unetfeats=48, scale_orsnetfeats=32, num_cab=8, kernel_size=3, reduction=4, bias_attr=False):
super(MPRNet, self).__init__()
act=nn.PReLU()
self.shallow_feat1 = nn.Sequential(conv(in_c, n_feat, kernel_size, bias_attr=bias_attr), CAB(n_feat,kernel_size, reduction, bias_attr=bias_attr, act=act))
self.shallow_feat2 = nn.Sequential(conv(in_c, n_feat, kernel_size, bias_attr=bias_attr), CAB(n_feat,kernel_size, reduction, bias_attr=bias_attr, act=act))
self.shallow_feat3 = nn.Sequential(conv(in_c, n_feat, kernel_size, bias_attr=bias_attr), CAB(n_feat,kernel_size, reduction, bias_attr=bias_attr, act=act))
# Cross Stage Feature Fusion (CSFF)
self.stage1_encoder = Encoder(n_feat, kernel_size, reduction, act, bias_attr, scale_unetfeats, csff=False)
self.stage1_decoder = Decoder(n_feat, kernel_size, reduction, act, bias_attr, scale_unetfeats)
self.stage2_encoder = Encoder(n_feat, kernel_size, reduction, act, bias_attr, scale_unetfeats, csff=True)
self.stage2_decoder = Decoder(n_feat, kernel_size, reduction, act, bias_attr, scale_unetfeats)
self.stage3_orsnet = ORSNet(n_feat, scale_orsnetfeats, kernel_size, reduction, act, bias_attr, scale_unetfeats, num_cab)
self.sam12 = SAM(n_feat, kernel_size=1, bias_attr=bias_attr)
self.sam23 = SAM(n_feat, kernel_size=1, bias_attr=bias_attr)
self.concat12 = conv(n_feat*2, n_feat, kernel_size, bias_attr=bias_attr)
self.concat23 = conv(n_feat*2, n_feat+scale_orsnetfeats, kernel_size, bias_attr=bias_attr)
self.tail = conv(n_feat+scale_orsnetfeats, out_c, kernel_size, bias_attr=bias_attr)
def forward(self, x3_img):
# Original-resolution Image for Stage 3
H = x3_img.shape[2]
W = x3_img.shape[3]
# Multi-Patch Hierarchy: Split Image into four non-overlapping patches
# Two Patches for Stage 2
x2top_img = x3_img[:,:,0:int(H/2),:]
x2bot_img = x3_img[:,:,int(H/2):H,:]
# Four Patches for Stage 1
x1ltop_img = x2top_img[:,:,:,0:int(W/2)]
x1rtop_img = x2top_img[:,:,:,int(W/2):W]
x1lbot_img = x2bot_img[:,:,:,0:int(W/2)]
x1rbot_img = x2bot_img[:,:,:,int(W/2):W]
##-------------------------------------------
##-------------- Stage 1---------------------
##-------------------------------------------
## Compute Shallow Features
x1ltop = self.shallow_feat1(x1ltop_img)
x1rtop = self.shallow_feat1(x1rtop_img)
x1lbot = self.shallow_feat1(x1lbot_img)
x1rbot = self.shallow_feat1(x1rbot_img)
## Process features of all 4 patches with Encoder of Stage 1
feat1_ltop = self.stage1_encoder(x1ltop)
feat1_rtop = self.stage1_encoder(x1rtop)
feat1_lbot = self.stage1_encoder(x1lbot)
feat1_rbot = self.stage1_encoder(x1rbot)
## Concat deep features
feat1_top = [paddle.concat((k,v), 3) for k,v in zip(feat1_ltop,feat1_rtop)]
feat1_bot = [paddle.concat((k,v), 3) for k,v in zip(feat1_lbot,feat1_rbot)]
## Pass features through Decoder of Stage 1
res1_top = self.stage1_decoder(feat1_top)
res1_bot = self.stage1_decoder(feat1_bot)
## Apply Supervised Attention Module (SAM)
x2top_samfeats, stage1_img_top = self.sam12(res1_top[0], x2top_img)
x2bot_samfeats, stage1_img_bot = self.sam12(res1_bot[0], x2bot_img)
## Output image at Stage 1
stage1_img = paddle.concat([stage1_img_top, stage1_img_bot],2)
##-------------------------------------------
##-------------- Stage 2---------------------
##-------------------------------------------
## Compute Shallow Features
x2top = self.shallow_feat2(x2top_img)
x2bot = self.shallow_feat2(x2bot_img)
## Concatenate SAM features of Stage 1 with shallow features of Stage 2
x2top_cat = self.concat12(paddle.concat([x2top, x2top_samfeats], 1))
x2bot_cat = self.concat12(paddle.concat([x2bot, x2bot_samfeats], 1))
## Process features of both patches with Encoder of Stage 2
feat2_top = self.stage2_encoder(x2top_cat, feat1_top, res1_top)
feat2_bot = self.stage2_encoder(x2bot_cat, feat1_bot, res1_bot)
## Concat deep features
feat2 = [paddle.concat((k,v), 2) for k,v in zip(feat2_top,feat2_bot)]
## Pass features through Decoder of Stage 2
res2 = self.stage2_decoder(feat2)
## Apply SAM
x3_samfeats, stage2_img = self.sam23(res2[0], x3_img)
##-------------------------------------------
##-------------- Stage 3---------------------
##-------------------------------------------
## Compute Shallow Features
x3 = self.shallow_feat3(x3_img)
## Concatenate SAM features of Stage 2 with shallow features of Stage 3
x3_cat = self.concat23(paddle.concat([x3, x3_samfeats], 1))
x3_cat = self.stage3_orsnet(x3_cat, feat2, res2)
stage3_img = self.tail(x3_cat)
return [stage3_img+x3_img, stage2_img, stage1_img]
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
from .builder import MODELS
from .base_model import BaseModel
from .generators.builder import build_generator
from .criterions.builder import build_criterion
from ..modules.init import reset_parameters, init_weights
@MODELS.register()
class MPRModel(BaseModel):
"""MPR Model.
Paper: MPR: Multi-Stage Progressive Image Restoration (CVPR 2021).
https://arxiv.org/abs/2102.02808
"""
def __init__(self, generator, char_criterion=None, edge_criterion=None):
"""Initialize the MPR class.
Args:
generator (dict): config of generator.
char_criterion (dict): config of char criterion.
edge_criterion (dict): config of edge criterion.
"""
super(MPRModel, self).__init__(generator)
self.current_iter = 1
self.nets['generator'] = build_generator(generator)
init_weights(self.nets['generator'])
if char_criterion:
self.char_criterion = build_criterion(char_criterion)
if edge_criterion:
self.edge_criterion = build_criterion(edge_criterion)
def setup_input(self, input):
self.target = input[0]
self.input_ = input[1]
def train_iter(self, optims=None):
optims['optim'].clear_gradients()
restored = self.nets['generator'](self.input_)
loss_char = []
loss_edge = []
for i in range(len(restored)):
loss_char.append(self.char_criterion(restored[i], self.target))
loss_edge.append(self.edge_criterion(restored[i], self.target))
loss_char = paddle.stack(loss_char)
loss_edge = paddle.stack(loss_edge)
loss_char = paddle.sum(loss_char)
loss_edge = paddle.sum(loss_edge)
loss = (loss_char) + (0.05 * loss_edge)
loss.backward()
optims['optim'].step()
self.losses['loss'] = loss.numpy()
def forward(self):
"""Run forward pass; called by both functions <train_iter> and <test_iter>."""
pass
def init_edvr_weight(net):
def reset_func(m):
if hasattr(m, 'weight') and (not isinstance(
m, (nn.BatchNorm, nn.BatchNorm2D))):
reset_parameters(m)
net.apply(reset_func)
...@@ -8,4 +8,4 @@ librosa==0.7.0 ...@@ -8,4 +8,4 @@ librosa==0.7.0
numba==0.48 numba==0.48
easydict easydict
munch munch
natsort
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册