From 7131e67c25236643a5cb7f3f379286d059241cd6 Mon Sep 17 00:00:00 2001 From: lyl120117 <278401555@qq.com> Date: Tue, 29 Jun 2021 21:49:37 +0800 Subject: [PATCH] add mpr deblur derain and denoise pretrained models for image restoration application (#352) * add mpr deblur derain and denoise pretrained models for image restoration application --- applications/tools/image_restoration.py | 54 ++++ configs/mprnet_deblurring.yaml | 66 +++++ ppgan/apps/__init__.py | 1 + ppgan/apps/mpr_predictor.py | 159 ++++++++++ ppgan/datasets/__init__.py | 1 + ppgan/datasets/mpr_dataset.py | 204 +++++++++++++ ppgan/models/__init__.py | 1 + ppgan/models/criterions/__init__.py | 2 +- ppgan/models/criterions/pixel_loss.py | 27 ++ ppgan/models/generators/__init__.py | 1 + ppgan/models/generators/mpr.py | 372 ++++++++++++++++++++++++ ppgan/models/mpr_model.py | 88 ++++++ requirements.txt | 2 +- 13 files changed, 976 insertions(+), 2 deletions(-) create mode 100644 applications/tools/image_restoration.py create mode 100644 configs/mprnet_deblurring.yaml create mode 100644 ppgan/apps/mpr_predictor.py create mode 100644 ppgan/datasets/mpr_dataset.py create mode 100644 ppgan/models/generators/mpr.py create mode 100644 ppgan/models/mpr_model.py diff --git a/applications/tools/image_restoration.py b/applications/tools/image_restoration.py new file mode 100644 index 0000000..bc409c6 --- /dev/null +++ b/applications/tools/image_restoration.py @@ -0,0 +1,54 @@ +import paddle +import os +import sys +sys.path.insert(0, os.getcwd()) +from ppgan.apps import MPRPredictor +import argparse + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--output_path", + type=str, + default='output_dir', + help="path to output image dir") + + parser.add_argument("--weight_path", + type=str, + default=None, + help="path to model checkpoint path") + + parser.add_argument("--seed", + type=int, + default=None, + help="sample random seed for model's image generation") + + parser.add_argument('--images_path', + default=None, + required=True, + type=str, + help='Single image or images directory.') + + parser.add_argument('--task', + required=True, + type=str, + help='Task to run', + choices=['Deblurring', 'Denoising', 'Deraining']) + + parser.add_argument("--cpu", + dest="cpu", + action="store_true", + help="cpu mode.") + + args = parser.parse_args() + + if args.cpu: + paddle.set_device('cpu') + + predictor = MPRPredictor( + images_path=args.images_path, + output_path=args.output_path, + weight_path=args.weight_path, + seed=args.seed, + task=args.task + ) + predictor.run() diff --git a/configs/mprnet_deblurring.yaml b/configs/mprnet_deblurring.yaml new file mode 100644 index 0000000..a69fb02 --- /dev/null +++ b/configs/mprnet_deblurring.yaml @@ -0,0 +1,66 @@ +total_iters: 100000 +output_dir: output_dir + +model: + name: MPRModel + generator: + name: MPRNet + + char_criterion: + name: CharbonnierLoss + edge_criterion: + name: EdgeLoss + +dataset: + train: + name: MPRTrain + rgb_dir: 'data/GoPro/train' + num_workers: 16 + batch_size: 4 + img_options: + patch_size: 256 + test: + name: MPRTrain + rgb_dir: 'data/GoPro/test' + num_workers: 16 + batch_size: 4 + img_options: + patch_size: 256 + +lr_scheduler: + name: CosineAnnealingRestartLR + learning_rate: !!float 2e-4 + periods: [25000, 25000, 25000, 25000] + restart_weights: [1, 1, 1, 1] + eta_min: !!float 1e-6 + +validate: + interval: 10 + save_img: false + + metrics: + psnr: # metric name, can be arbitrary + name: PSNR + crop_border: 4 + test_y_channel: True + ssim: 
+      name: SSIM
+      crop_border: 4
+      test_y_channel: True
+
+optimizer:
+  name: Adam
+  # add parameters of net_name to optim
+  # net_name should be in self.nets
+  net_names:
+    - generator
+  beta1: 0.9
+  beta2: 0.999
+  epsilon: 1e-8
+
+log_config:
+  interval: 10
+  visiual_interval: 5000
+
+snapshot_config:
+  interval: 5000
diff --git a/ppgan/apps/__init__.py b/ppgan/apps/__init__.py
index 647a212..f0bfc1c 100644
--- a/ppgan/apps/__init__.py
+++ b/ppgan/apps/__init__.py
@@ -25,3 +25,4 @@ from .photo2cartoon_predictor import Photo2CartoonPredictor
 from .styleganv2_predictor import StyleGANv2Predictor
 from .pixel2style2pixel_predictor import Pixel2Style2PixelPredictor
 from .wav2lip_predictor import Wav2LipPredictor
+from .mpr_predictor import MPRPredictor
diff --git a/ppgan/apps/mpr_predictor.py b/ppgan/apps/mpr_predictor.py
new file mode 100644
index 0000000..e76cc47
--- /dev/null
+++ b/ppgan/apps/mpr_predictor.py
@@ -0,0 +1,159 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import random
+from natsort import natsorted
+from glob import glob
+import numpy as np
+import cv2
+from PIL import Image
+import paddle
+from .base_predictor import BasePredictor
+from ppgan.models.generators import MPRNet
+from ppgan.utils.download import get_path_from_url
+from ppgan.datasets.mpr_dataset import to_tensor
+from paddle.vision.transforms import Pad
+from tqdm import tqdm
+
+model_cfgs = {
+    'Deblurring': {
+        'model_urls':
+        'https://paddlegan.bj.bcebos.com/models/MPR_Deblurring.pdparams',
+        'n_feat': 96,
+        'scale_unetfeats': 48,
+        'scale_orsnetfeats': 32,
+    },
+    'Denoising': {
+        'model_urls':
+        'https://paddlegan.bj.bcebos.com/models/MPR_Denoising.pdparams',
+        'n_feat': 80,
+        'scale_unetfeats': 48,
+        'scale_orsnetfeats': 32,
+    },
+    'Deraining': {
+        'model_urls':
+        'https://paddlegan.bj.bcebos.com/models/MPR_Deraining.pdparams',
+        'n_feat': 40,
+        'scale_unetfeats': 20,
+        'scale_orsnetfeats': 16,
+    }
+}
+
+
+class MPRPredictor(BasePredictor):
+    def __init__(self,
+                 images_path=None,
+                 output_path='output_dir',
+                 weight_path=None,
+                 seed=None,
+                 task=None):
+        self.output_path = output_path
+        self.images_path = images_path
+        self.task = task
+        self.max_size = 640
+        self.img_multiple_of = 8
+
+        # The task decides both the network width and the default weights,
+        # so it must be valid even when a local weight path is given.
+        if task not in model_cfgs.keys():
+            raise ValueError(
+                f'Unsupported task {task}, must be one of '
+                f'{list(model_cfgs.keys())}')
+
+        if weight_path is None:
+            weight_path = get_path_from_url(model_cfgs[task]['model_urls'])
+        checkpoint = paddle.load(weight_path)
+
+        self.generator = MPRNet(
+            n_feat=model_cfgs[task]['n_feat'],
+            scale_unetfeats=model_cfgs[task]['scale_unetfeats'],
+            scale_orsnetfeats=model_cfgs[task]['scale_orsnetfeats'])
+        self.generator.set_state_dict(checkpoint)
+        self.generator.eval()
+
+        if seed is not None:
+            paddle.seed(seed)
+            random.seed(seed)
+            np.random.seed(seed)
+
+    def get_images(self, images_path):
+        if os.path.isdir(images_path):
+            return natsorted(
+                glob(os.path.join(images_path, '*.jpg')) +
+                glob(os.path.join(images_path, '*.JPG')) +
+                glob(os.path.join(images_path, '*.png')) +
+                glob(os.path.join(images_path, '*.PNG')))
+        else:
+            return [images_path]
+
+    def read_image(self, image_file):
+        img = Image.open(image_file).convert('RGB')
+        max_length = max(img.width, img.height)
+        if max_length > self.max_size:
+            ratio = max_length / self.max_size
+            dw = int(img.width / ratio)
+            dh = int(img.height / ratio)
+            img = img.resize((dw, dh))
+        return img
+
+    def run(self):
+        os.makedirs(self.output_path, exist_ok=True)
+        task_path = os.path.join(self.output_path, self.task)
+        os.makedirs(task_path, exist_ok=True)
+        image_files = self.get_images(self.images_path)
+        for image_file in tqdm(image_files):
+            img = self.read_image(image_file)
+            image_name = os.path.basename(image_file)
+            img.save(os.path.join(task_path, image_name))
+            tmps = image_name.split('.')
+            assert len(
+                tmps) == 2, f'Invalid image name: {image_name}, too many "."'
+            restoration_save_path = os.path.join(
+                task_path, f'{tmps[0]}_restoration.{tmps[1]}')
+            input_ = to_tensor(img)
+
+            # Reflect-pad the input so height and width are multiples of 8
+            h, w = input_.shape[1], input_.shape[2]
+
+            H, W = ((h + self.img_multiple_of) //
+                    self.img_multiple_of) * self.img_multiple_of, (
+                        (w + self.img_multiple_of) //
+                        self.img_multiple_of) * self.img_multiple_of
+            padh = H - h if h % self.img_multiple_of != 0 else 0
+            padw = W - w if w % self.img_multiple_of != 0 else 0
+            input_ = paddle.to_tensor(input_)
+            transform = Pad((0, 0, padw, padh), padding_mode='reflect')
+            input_ = transform(input_)
+
+            input_ = paddle.to_tensor(np.expand_dims(input_.numpy(), 0))
+
+            with paddle.no_grad():
+                restored = self.generator(input_)
+            # The generator returns [stage3, stage2, stage1]; keep the
+            # final full-resolution stage.
+            restored = restored[0]
+            restored = paddle.clip(restored, 0, 1)
+
+            # Unpad the output
+            restored = restored[:, :, :h, :w]
+
+            restored = restored.numpy()
+            restored = restored.transpose(0, 2, 3, 1)
+            restored = restored[0]
+            restored = restored * 255
+            restored = restored.astype(np.uint8)
+
+            cv2.imwrite(restoration_save_path,
+                        cv2.cvtColor(restored, cv2.COLOR_RGB2BGR))
+
+        print('Done, output path is:', task_path)
diff --git a/ppgan/datasets/__init__.py b/ppgan/datasets/__init__.py
index 843fd74..ef9d7ad 100755
--- a/ppgan/datasets/__init__.py
+++ b/ppgan/datasets/__init__.py
@@ -24,3 +24,4 @@ from .starganv2_dataset import StarGANv2Dataset
 from .edvr_dataset import REDSDataset
 from .firstorder_dataset import FirstOrderDataset
 from .lapstyle_dataset import LapStyleDataset
+from .mpr_dataset import MPRTrain, MPRVal, MPRTest
diff --git a/ppgan/datasets/mpr_dataset.py b/ppgan/datasets/mpr_dataset.py
new file mode 100644
index 0000000..8c243cb
--- /dev/null
+++ b/ppgan/datasets/mpr_dataset.py
@@ -0,0 +1,204 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
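+
+# Expected directory layout (inferred from the file listing below, matching
+# the GoPro layout used in configs/mprnet_deblurring.yaml):
+#
+#     rgb_dir/
+#         input/    degraded images (e.g. data/GoPro/train/input)
+#         target/   corresponding ground-truth images
+#
+# Input and target files are paired by sorted filename order, so both
+# folders must contain the same number of matching images.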
+
+import os
+import random
+
+import numpy as np
+from PIL import Image
+from paddle.io import Dataset
+from paddle.vision.transforms.functional import (to_tensor, pad,
+                                                 adjust_brightness,
+                                                 adjust_saturation, rotate,
+                                                 hflip, vflip, center_crop)
+
+from .builder import DATASETS
+
+
+def is_image_file(filename):
+    return any(
+        filename.endswith(extension)
+        for extension in ['jpeg', 'JPEG', 'jpg', 'png', 'JPG', 'PNG', 'gif'])
+
+
+@DATASETS.register()
+class MPRTrain(Dataset):
+    def __init__(self, rgb_dir, img_options=None):
+        super(MPRTrain, self).__init__()
+
+        inp_files = sorted(os.listdir(os.path.join(rgb_dir, 'input')))
+        tar_files = sorted(os.listdir(os.path.join(rgb_dir, 'target')))
+
+        self.inp_filenames = [
+            os.path.join(rgb_dir, 'input', x) for x in inp_files
+            if is_image_file(x)
+        ]
+        self.tar_filenames = [
+            os.path.join(rgb_dir, 'target', x) for x in tar_files
+            if is_image_file(x)
+        ]
+
+        self.img_options = img_options
+        self.sizex = len(self.tar_filenames)  # get the size of target
+
+        self.ps = self.img_options['patch_size']
+
+    def __len__(self):
+        return self.sizex
+
+    def __getitem__(self, index):
+        index_ = index % self.sizex
+        ps = self.ps
+
+        inp_path = self.inp_filenames[index_]
+        tar_path = self.tar_filenames[index_]
+
+        inp_img = Image.open(inp_path)
+        tar_img = Image.open(tar_path)
+
+        w, h = tar_img.size
+        padw = ps - w if w < ps else 0
+        padh = ps - h if h < ps else 0
+
+        # Reflect-pad in case the image is smaller than patch_size
+        if padw != 0 or padh != 0:
+            inp_img = pad(inp_img, (0, 0, padw, padh),
+                          padding_mode='reflect')
+            tar_img = pad(tar_img, (0, 0, padw, padh),
+                          padding_mode='reflect')
+
+        aug = random.randint(0, 2)
+        if aug == 1:
+            # factor 1 currently leaves brightness unchanged
+            inp_img = adjust_brightness(inp_img, 1)
+            tar_img = adjust_brightness(tar_img, 1)
+
+        aug = random.randint(0, 2)
+        if aug == 1:
+            sat_factor = 1 + (0.2 - 0.4 * np.random.rand())
+            inp_img = adjust_saturation(inp_img, sat_factor)
+            tar_img = adjust_saturation(tar_img, sat_factor)
+
+        # Geometric augmentations: identity, flips, 90-degree rotations and
+        # flip+rotation combinations. expand=True keeps the whole image when
+        # rotating non-square inputs.
+        aug = random.randint(0, 8)
+        if aug == 1:
+            inp_img = vflip(inp_img)
+            tar_img = vflip(tar_img)
+        elif aug == 2:
+            inp_img = hflip(inp_img)
+            tar_img = hflip(tar_img)
+        elif aug == 3:
+            inp_img = rotate(inp_img, 90, expand=True)
+            tar_img = rotate(tar_img, 90, expand=True)
+        elif aug == 4:
+            inp_img = rotate(inp_img, 90 * 2, expand=True)
+            tar_img = rotate(tar_img, 90 * 2, expand=True)
+        elif aug == 5:
+            inp_img = rotate(inp_img, 90 * 3, expand=True)
+            tar_img = rotate(tar_img, 90 * 3, expand=True)
+        elif aug == 6:
+            inp_img = rotate(vflip(inp_img), 90, expand=True)
+            tar_img = rotate(vflip(tar_img), 90, expand=True)
+        elif aug == 7:
+            inp_img = rotate(hflip(inp_img), 90, expand=True)
+            tar_img = rotate(hflip(tar_img), 90, expand=True)
+
+        inp_img = to_tensor(inp_img)
+        tar_img = to_tensor(tar_img)
+
+        hh, ww = tar_img.shape[1], tar_img.shape[2]
+
+        rr = random.randint(0, hh - ps)
+        cc = random.randint(0, ww - ps)
+
+        # Crop a random ps x ps patch from the same location in both images
+        inp_img = inp_img[:, rr:rr + ps, cc:cc + ps]
+        tar_img = tar_img[:, rr:rr + ps, cc:cc + ps]
+
+        filename = os.path.splitext(os.path.split(tar_path)[-1])[0]
+
+        return tar_img, inp_img, filename
+
+
+@DATASETS.register()
+class MPRVal(Dataset):
+    def __init__(self, rgb_dir, img_options=None, rgb_dir2=None):
+        super(MPRVal, self).__init__()
+
+        inp_files = sorted(os.listdir(os.path.join(rgb_dir, 'input')))
+        tar_files = sorted(os.listdir(os.path.join(rgb_dir, 'target')))
+
+        self.inp_filenames = [
+            os.path.join(rgb_dir, 'input', x) for x in inp_files
+            if is_image_file(x)
+        ]
+        self.tar_filenames = [
+            os.path.join(rgb_dir, 'target', 
x) for x in tar_files + if is_image_file(x) + ] + + self.img_options = img_options + self.sizex = len(self.tar_filenames) # get the size of target + + self.ps = self.img_options['patch_size'] + + def __len__(self): + return self.sizex + + def __getitem__(self, index): + index_ = index % self.sizex + ps = self.ps + + inp_path = self.inp_filenames[index_] + tar_path = self.tar_filenames[index_] + + inp_img = Image.open(inp_path) + tar_img = Image.open(tar_path) + + # Validate on center crop + if self.ps is not None: + inp_img = center_crop(inp_img, (ps, ps)) + tar_img = center_crop(tar_img, (ps, ps)) + + inp_img = to_tensor(inp_img) + tar_img = to_tensor(tar_img) + + filename = os.path.splitext(os.path.split(tar_path)[-1])[0] + + return tar_img, inp_img, filename + + +@DATASETS.register() +class MPRTest(Dataset): + def __init__(self, inp_dir, img_options): + super(MPRTest, self).__init__() + + inp_files = sorted(os.listdir(inp_dir)) + self.inp_filenames = [ + os.path.join(inp_dir, x) for x in inp_files if is_image_file(x) + ] + + self.inp_size = len(self.inp_filenames) + self.img_options = img_options + + def __len__(self): + return self.inp_size + + def __getitem__(self, index): + + path_inp = self.inp_filenames[index] + filename = os.path.splitext(os.path.split(path_inp)[-1])[0] + inp = Image.open(path_inp) + + inp = to_tensor(inp) + return inp, filename diff --git a/ppgan/models/__init__.py b/ppgan/models/__init__.py index 4367f4d..31cf00d 100644 --- a/ppgan/models/__init__.py +++ b/ppgan/models/__init__.py @@ -30,3 +30,4 @@ from .starganv2_model import StarGANv2Model from .edvr_model import EDVRModel from .firstorder_model import FirstOrderModel from .lapstyle_model import LapStyleDraModel, LapStyleRevFirstModel, LapStyleRevSecondModel +from .mpr_model import MPRModel diff --git a/ppgan/models/criterions/__init__.py b/ppgan/models/criterions/__init__.py index cd760e7..b4ab040 100644 --- a/ppgan/models/criterions/__init__.py +++ b/ppgan/models/criterions/__init__.py @@ -2,6 +2,6 @@ from .gan_loss import GANLoss from .perceptual_loss import PerceptualLoss from .pixel_loss import L1Loss, MSELoss, CharbonnierLoss, \ CalcStyleEmdLoss, CalcContentReltLoss, \ - CalcContentLoss, CalcStyleLoss + CalcContentLoss, CalcStyleLoss, EdgeLoss from .builder import build_criterion diff --git a/ppgan/models/criterions/pixel_loss.py b/ppgan/models/criterions/pixel_loss.py index 11df70f..f33e632 100644 --- a/ppgan/models/criterions/pixel_loss.py +++ b/ppgan/models/criterions/pixel_loss.py @@ -17,6 +17,7 @@ from ..generators.generater_lapstyle import calc_mean_std, mean_variance_norm import paddle import paddle.nn as nn +import paddle.nn.functional as F from .builder import CRITERIONS @@ -234,3 +235,29 @@ class CalcStyleLoss(): target_mean, target_std = calc_mean_std(target) return self.mse_loss(pred_mean, target_mean) + self.mse_loss( pred_std, target_std) + + +@CRITERIONS.register() +class EdgeLoss(): + def __init__(self): + k = paddle.to_tensor([[.05, .25, .4, .25, .05]]) + self.kernel = paddle.matmul(k.t(),k).unsqueeze(0).tile([3,1,1,1]) + self.loss = CharbonnierLoss() + + def conv_gauss(self, img): + n_channels, _, kw, kh = self.kernel.shape + img = F.pad(img, [kw//2, kh//2, kw//2, kh//2], mode='replicate') + return F.conv2d(img, self.kernel, groups=n_channels) + + def laplacian_kernel(self, current): + filtered = self.conv_gauss(current) # filter + down = filtered[:,:,::2,::2] # downsample + new_filter = paddle.zeros_like(filtered) + new_filter[:,:,::2,::2] = down*4 # upsample + filtered = 
self.conv_gauss(new_filter) # filter
+        diff = current - filtered
+        return diff
+
+    def __call__(self, x, y):
+        loss = self.loss(self.laplacian_kernel(x), self.laplacian_kernel(y))
+        return loss
\ No newline at end of file
diff --git a/ppgan/models/generators/__init__.py b/ppgan/models/generators/__init__.py
index d432500..8e19310 100755
--- a/ppgan/models/generators/__init__.py
+++ b/ppgan/models/generators/__init__.py
@@ -30,3 +30,4 @@ from .generator_starganv2 import StarGANv2Generator, StarGANv2Style, StarGANv2Ma
 from .edvr import EDVRNet
 from .generator_firstorder import FirstOrderGenerator
 from .generater_lapstyle import DecoderNet, Encoder, RevisionNet
+from .mpr import MPRNet
diff --git a/ppgan/models/generators/mpr.py b/ppgan/models/generators/mpr.py
new file mode 100644
index 0000000..53bb7b5
--- /dev/null
+++ b/ppgan/models/generators/mpr.py
@@ -0,0 +1,372 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+from paddle import nn
+import paddle.nn.functional as F
+
+from .builder import GENERATORS
+
+
+##########################################################################
+def conv(in_channels, out_channels, kernel_size, bias_attr=False, stride=1):
+    return nn.Conv2D(
+        in_channels, out_channels, kernel_size,
+        padding=(kernel_size // 2), bias_attr=bias_attr, stride=stride)
+
+
+##########################################################################
+## Channel Attention Layer
+class CALayer(nn.Layer):
+    def __init__(self, channel, reduction=16, bias_attr=False):
+        super(CALayer, self).__init__()
+        # global average pooling: feature --> point
+        self.avg_pool = nn.AdaptiveAvgPool2D(1)
+        # feature channel downscale and upscale --> channel weight
+        self.conv_du = nn.Sequential(
+            nn.Conv2D(channel, channel // reduction, 1, padding=0, bias_attr=bias_attr),
+            nn.ReLU(),
+            nn.Conv2D(channel // reduction, channel, 1, padding=0, bias_attr=bias_attr),
+            nn.Sigmoid()
+        )
+
+    def forward(self, x):
+        y = self.avg_pool(x)
+        y = self.conv_du(y)
+        return x * y
+
+
+##########################################################################
+## Channel Attention Block (CAB)
+class CAB(nn.Layer):
+    def __init__(self, n_feat, kernel_size, reduction, bias_attr, act):
+        super(CAB, self).__init__()
+        modules_body = []
+        modules_body.append(conv(n_feat, n_feat, kernel_size, bias_attr=bias_attr))
+        modules_body.append(act)
+        modules_body.append(conv(n_feat, n_feat, kernel_size, bias_attr=bias_attr))
+
+        self.CA = CALayer(n_feat, reduction, bias_attr=bias_attr)
+        self.body = nn.Sequential(*modules_body)
+
+    def forward(self, x):
+        res = self.body(x)
+        res = self.CA(res)
+        res += x
+        return res
+
+
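+# A quick, illustrative shape check (not part of the network definition):
+# a CAB preserves both spatial size and channel count, e.g. with the
+# deblurring config values:
+#
+#     cab = CAB(n_feat=96, kernel_size=3, reduction=4, bias_attr=False,
+#               act=nn.PReLU())
+#     out = cab(paddle.randn([1, 96, 64, 64]))  # shape stays [1, 96, 64, 64]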
+########################################################################## +##---------- Resizing Modules ---------- +class DownSample(nn.Layer): + def __init__(self, in_channels,s_factor): + super(DownSample, self).__init__() + self.down = nn.Sequential(nn.Upsample(scale_factor=0.5, mode='bilinear', align_corners=False), + nn.Conv2D(in_channels, in_channels+s_factor, 1, stride=1, padding=0, bias_attr=False)) + + def forward(self, x): + x = self.down(x) + return x + +class UpSample(nn.Layer): + def __init__(self, in_channels,s_factor): + super(UpSample, self).__init__() + self.up = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False), + nn.Conv2D(in_channels+s_factor, in_channels, 1, stride=1, padding=0, bias_attr=False)) + + def forward(self, x): + x = self.up(x) + return x + + +class SkipUpSample(nn.Layer): + def __init__(self, in_channels,s_factor): + super(SkipUpSample, self).__init__() + self.up = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False), + nn.Conv2D(in_channels+s_factor, in_channels, 1, stride=1, padding=0, bias_attr=False)) + + def forward(self, x, y): + x = self.up(x) + x = x + y + return x + +########################################################################## +## U-Net +class Encoder(nn.Layer): + def __init__(self, n_feat, kernel_size, reduction, act, bias_attr, scale_unetfeats, csff): + super(Encoder, self).__init__() + + self.encoder_level1 = [CAB(n_feat, kernel_size, reduction, bias_attr=bias_attr, act=act) for _ in range(2)] + self.encoder_level2 = [CAB(n_feat+scale_unetfeats, kernel_size, reduction, bias_attr=bias_attr, act=act) for _ in range(2)] + self.encoder_level3 = [CAB(n_feat+(scale_unetfeats*2), kernel_size, reduction, bias_attr=bias_attr, act=act) for _ in range(2)] + + self.encoder_level1 = nn.Sequential(*self.encoder_level1) + self.encoder_level2 = nn.Sequential(*self.encoder_level2) + self.encoder_level3 = nn.Sequential(*self.encoder_level3) + + self.down12 = DownSample(n_feat, scale_unetfeats) + self.down23 = DownSample(n_feat+scale_unetfeats, scale_unetfeats) + + # Cross Stage Feature Fusion (CSFF) + if csff: + self.csff_enc1 = nn.Conv2D(n_feat, n_feat, kernel_size=1, bias_attr=bias_attr) + self.csff_enc2 = nn.Conv2D(n_feat+scale_unetfeats, n_feat+scale_unetfeats, kernel_size=1, bias_attr=bias_attr) + self.csff_enc3 = nn.Conv2D(n_feat+(scale_unetfeats*2), n_feat+(scale_unetfeats*2), kernel_size=1, bias_attr=bias_attr) + + self.csff_dec1 = nn.Conv2D(n_feat, n_feat, kernel_size=1, bias_attr=bias_attr) + self.csff_dec2 = nn.Conv2D(n_feat+scale_unetfeats, n_feat+scale_unetfeats, kernel_size=1, bias_attr=bias_attr) + self.csff_dec3 = nn.Conv2D(n_feat+(scale_unetfeats*2), n_feat+(scale_unetfeats*2), kernel_size=1, bias_attr=bias_attr) + + def forward(self, x, encoder_outs=None, decoder_outs=None): + enc1 = self.encoder_level1(x) + if (encoder_outs is not None) and (decoder_outs is not None): + enc1 = enc1 + self.csff_enc1(encoder_outs[0]) + self.csff_dec1(decoder_outs[0]) + + x = self.down12(enc1) + + enc2 = self.encoder_level2(x) + if (encoder_outs is not None) and (decoder_outs is not None): + enc2 = enc2 + self.csff_enc2(encoder_outs[1]) + self.csff_dec2(decoder_outs[1]) + + x = self.down23(enc2) + + enc3 = self.encoder_level3(x) + if (encoder_outs is not None) and (decoder_outs is not None): + enc3 = enc3 + self.csff_enc3(encoder_outs[2]) + self.csff_dec3(decoder_outs[2]) + + return [enc1, enc2, enc3] + + +class Decoder(nn.Layer): + def __init__(self, n_feat, kernel_size, reduction, act, 
bias_attr, scale_unetfeats): + super(Decoder, self).__init__() + + self.decoder_level1 = [CAB(n_feat, kernel_size, reduction, bias_attr=bias_attr, act=act) for _ in range(2)] + self.decoder_level2 = [CAB(n_feat+scale_unetfeats, kernel_size, reduction, bias_attr=bias_attr, act=act) for _ in range(2)] + self.decoder_level3 = [CAB(n_feat+(scale_unetfeats*2), kernel_size, reduction, bias_attr=bias_attr, act=act) for _ in range(2)] + + self.decoder_level1 = nn.Sequential(*self.decoder_level1) + self.decoder_level2 = nn.Sequential(*self.decoder_level2) + self.decoder_level3 = nn.Sequential(*self.decoder_level3) + + self.skip_attn1 = CAB(n_feat, kernel_size, reduction, bias_attr=bias_attr, act=act) + self.skip_attn2 = CAB(n_feat+scale_unetfeats, kernel_size, reduction, bias_attr=bias_attr, act=act) + + self.up21 = SkipUpSample(n_feat, scale_unetfeats) + self.up32 = SkipUpSample(n_feat+scale_unetfeats, scale_unetfeats) + + def forward(self, outs): + enc1, enc2, enc3 = outs + dec3 = self.decoder_level3(enc3) + + x = self.up32(dec3, self.skip_attn2(enc2)) + dec2 = self.decoder_level2(x) + + x = self.up21(dec2, self.skip_attn1(enc1)) + dec1 = self.decoder_level1(x) + + return [dec1,dec2,dec3] + + +########################################################################## +## Original Resolution Block (ORB) +class ORB(nn.Layer): + def __init__(self, n_feat, kernel_size, reduction, act, bias_attr, num_cab): + super(ORB, self).__init__() + modules_body = [] + modules_body = [CAB(n_feat, kernel_size, reduction, bias_attr=bias_attr, act=act) for _ in range(num_cab)] + modules_body.append(conv(n_feat, n_feat, kernel_size)) + self.body = nn.Sequential(*modules_body) + + def forward(self, x): + res = self.body(x) + res += x + return res + + +########################################################################## +class ORSNet(nn.Layer): + def __init__(self, n_feat, scale_orsnetfeats, kernel_size, reduction, act, bias_attr, scale_unetfeats, num_cab): + super(ORSNet, self).__init__() + + self.orb1 = ORB(n_feat+scale_orsnetfeats, kernel_size, reduction, act, bias_attr, num_cab) + self.orb2 = ORB(n_feat+scale_orsnetfeats, kernel_size, reduction, act, bias_attr, num_cab) + self.orb3 = ORB(n_feat+scale_orsnetfeats, kernel_size, reduction, act, bias_attr, num_cab) + + self.up_enc1 = UpSample(n_feat, scale_unetfeats) + self.up_dec1 = UpSample(n_feat, scale_unetfeats) + + self.up_enc2 = nn.Sequential(UpSample(n_feat+scale_unetfeats, scale_unetfeats), UpSample(n_feat, scale_unetfeats)) + self.up_dec2 = nn.Sequential(UpSample(n_feat+scale_unetfeats, scale_unetfeats), UpSample(n_feat, scale_unetfeats)) + + self.conv_enc1 = nn.Conv2D(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias_attr=bias_attr) + self.conv_enc2 = nn.Conv2D(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias_attr=bias_attr) + self.conv_enc3 = nn.Conv2D(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias_attr=bias_attr) + + self.conv_dec1 = nn.Conv2D(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias_attr=bias_attr) + self.conv_dec2 = nn.Conv2D(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias_attr=bias_attr) + self.conv_dec3 = nn.Conv2D(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias_attr=bias_attr) + + def forward(self, x, encoder_outs, decoder_outs): + x = self.orb1(x) + x = x + self.conv_enc1(encoder_outs[0]) + self.conv_dec1(decoder_outs[0]) + + x = self.orb2(x) + x = x + self.conv_enc2(self.up_enc1(encoder_outs[1])) + self.conv_dec2(self.up_dec1(decoder_outs[1])) + + x = self.orb3(x) + x = x + 
self.conv_enc3(self.up_enc2(encoder_outs[2])) + self.conv_dec3(self.up_dec2(decoder_outs[2])) + + return x + + +########################################################################## +## Supervised Attention Module +class SAM(nn.Layer): + def __init__(self, n_feat, kernel_size, bias_attr): + super(SAM, self).__init__() + self.conv1 = conv(n_feat, n_feat, kernel_size, bias_attr=bias_attr) + self.conv2 = conv(n_feat, 3, kernel_size, bias_attr=bias_attr) + self.conv3 = conv(3, n_feat, kernel_size, bias_attr=bias_attr) + + def forward(self, x, x_img): + x1 = self.conv1(x) + img = self.conv2(x) + x_img + x2 = F.sigmoid(self.conv3(img)) + x1 = x1*x2 + x1 = x1+x + return x1, img + + +@GENERATORS.register() +class MPRNet(nn.Layer): + def __init__(self, in_c=3, out_c=3, n_feat=96, scale_unetfeats=48, scale_orsnetfeats=32, num_cab=8, kernel_size=3, reduction=4, bias_attr=False): + super(MPRNet, self).__init__() + act=nn.PReLU() + self.shallow_feat1 = nn.Sequential(conv(in_c, n_feat, kernel_size, bias_attr=bias_attr), CAB(n_feat,kernel_size, reduction, bias_attr=bias_attr, act=act)) + self.shallow_feat2 = nn.Sequential(conv(in_c, n_feat, kernel_size, bias_attr=bias_attr), CAB(n_feat,kernel_size, reduction, bias_attr=bias_attr, act=act)) + self.shallow_feat3 = nn.Sequential(conv(in_c, n_feat, kernel_size, bias_attr=bias_attr), CAB(n_feat,kernel_size, reduction, bias_attr=bias_attr, act=act)) + + # Cross Stage Feature Fusion (CSFF) + self.stage1_encoder = Encoder(n_feat, kernel_size, reduction, act, bias_attr, scale_unetfeats, csff=False) + self.stage1_decoder = Decoder(n_feat, kernel_size, reduction, act, bias_attr, scale_unetfeats) + + self.stage2_encoder = Encoder(n_feat, kernel_size, reduction, act, bias_attr, scale_unetfeats, csff=True) + self.stage2_decoder = Decoder(n_feat, kernel_size, reduction, act, bias_attr, scale_unetfeats) + + self.stage3_orsnet = ORSNet(n_feat, scale_orsnetfeats, kernel_size, reduction, act, bias_attr, scale_unetfeats, num_cab) + + self.sam12 = SAM(n_feat, kernel_size=1, bias_attr=bias_attr) + self.sam23 = SAM(n_feat, kernel_size=1, bias_attr=bias_attr) + + self.concat12 = conv(n_feat*2, n_feat, kernel_size, bias_attr=bias_attr) + self.concat23 = conv(n_feat*2, n_feat+scale_orsnetfeats, kernel_size, bias_attr=bias_attr) + self.tail = conv(n_feat+scale_orsnetfeats, out_c, kernel_size, bias_attr=bias_attr) + + def forward(self, x3_img): + # Original-resolution Image for Stage 3 + H = x3_img.shape[2] + W = x3_img.shape[3] + + # Multi-Patch Hierarchy: Split Image into four non-overlapping patches + + # Two Patches for Stage 2 + x2top_img = x3_img[:,:,0:int(H/2),:] + x2bot_img = x3_img[:,:,int(H/2):H,:] + + # Four Patches for Stage 1 + x1ltop_img = x2top_img[:,:,:,0:int(W/2)] + x1rtop_img = x2top_img[:,:,:,int(W/2):W] + x1lbot_img = x2bot_img[:,:,:,0:int(W/2)] + x1rbot_img = x2bot_img[:,:,:,int(W/2):W] + + ##------------------------------------------- + ##-------------- Stage 1--------------------- + ##------------------------------------------- + ## Compute Shallow Features + x1ltop = self.shallow_feat1(x1ltop_img) + x1rtop = self.shallow_feat1(x1rtop_img) + x1lbot = self.shallow_feat1(x1lbot_img) + x1rbot = self.shallow_feat1(x1rbot_img) + + ## Process features of all 4 patches with Encoder of Stage 1 + feat1_ltop = self.stage1_encoder(x1ltop) + feat1_rtop = self.stage1_encoder(x1rtop) + feat1_lbot = self.stage1_encoder(x1lbot) + feat1_rbot = self.stage1_encoder(x1rbot) + + ## Concat deep features + feat1_top = [paddle.concat((k,v), 3) for k,v in 
zip(feat1_ltop,feat1_rtop)] + feat1_bot = [paddle.concat((k,v), 3) for k,v in zip(feat1_lbot,feat1_rbot)] + + ## Pass features through Decoder of Stage 1 + res1_top = self.stage1_decoder(feat1_top) + res1_bot = self.stage1_decoder(feat1_bot) + + ## Apply Supervised Attention Module (SAM) + x2top_samfeats, stage1_img_top = self.sam12(res1_top[0], x2top_img) + x2bot_samfeats, stage1_img_bot = self.sam12(res1_bot[0], x2bot_img) + + ## Output image at Stage 1 + stage1_img = paddle.concat([stage1_img_top, stage1_img_bot],2) + ##------------------------------------------- + ##-------------- Stage 2--------------------- + ##------------------------------------------- + ## Compute Shallow Features + x2top = self.shallow_feat2(x2top_img) + x2bot = self.shallow_feat2(x2bot_img) + + ## Concatenate SAM features of Stage 1 with shallow features of Stage 2 + x2top_cat = self.concat12(paddle.concat([x2top, x2top_samfeats], 1)) + x2bot_cat = self.concat12(paddle.concat([x2bot, x2bot_samfeats], 1)) + + ## Process features of both patches with Encoder of Stage 2 + feat2_top = self.stage2_encoder(x2top_cat, feat1_top, res1_top) + feat2_bot = self.stage2_encoder(x2bot_cat, feat1_bot, res1_bot) + + ## Concat deep features + feat2 = [paddle.concat((k,v), 2) for k,v in zip(feat2_top,feat2_bot)] + + ## Pass features through Decoder of Stage 2 + res2 = self.stage2_decoder(feat2) + + ## Apply SAM + x3_samfeats, stage2_img = self.sam23(res2[0], x3_img) + + + ##------------------------------------------- + ##-------------- Stage 3--------------------- + ##------------------------------------------- + ## Compute Shallow Features + x3 = self.shallow_feat3(x3_img) + + ## Concatenate SAM features of Stage 2 with shallow features of Stage 3 + x3_cat = self.concat23(paddle.concat([x3, x3_samfeats], 1)) + + x3_cat = self.stage3_orsnet(x3_cat, feat2, res2) + + stage3_img = self.tail(x3_cat) + + return [stage3_img+x3_img, stage2_img, stage1_img] diff --git a/ppgan/models/mpr_model.py b/ppgan/models/mpr_model.py new file mode 100644 index 0000000..1c2c7cb --- /dev/null +++ b/ppgan/models/mpr_model.py @@ -0,0 +1,88 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.nn as nn + +from .builder import MODELS +from .base_model import BaseModel +from .generators.builder import build_generator +from .criterions.builder import build_criterion +from ..modules.init import reset_parameters, init_weights + + +@MODELS.register() +class MPRModel(BaseModel): + """MPR Model. + + Paper: MPR: Multi-Stage Progressive Image Restoration (CVPR 2021). + https://arxiv.org/abs/2102.02808 + """ + def __init__(self, generator, char_criterion=None, edge_criterion=None): + """Initialize the MPR class. + + Args: + generator (dict): config of generator. + char_criterion (dict): config of char criterion. + edge_criterion (dict): config of edge criterion. 
+ """ + super(MPRModel, self).__init__(generator) + self.current_iter = 1 + + self.nets['generator'] = build_generator(generator) + init_weights(self.nets['generator']) + + if char_criterion: + self.char_criterion = build_criterion(char_criterion) + if edge_criterion: + self.edge_criterion = build_criterion(edge_criterion) + + def setup_input(self, input): + self.target = input[0] + self.input_ = input[1] + + def train_iter(self, optims=None): + optims['optim'].clear_gradients() + + restored = self.nets['generator'](self.input_) + + loss_char = [] + loss_edge = [] + + for i in range(len(restored)): + loss_char.append(self.char_criterion(restored[i], self.target)) + loss_edge.append(self.edge_criterion(restored[i], self.target)) + loss_char = paddle.stack(loss_char) + loss_edge = paddle.stack(loss_edge) + loss_char = paddle.sum(loss_char) + loss_edge = paddle.sum(loss_edge) + + loss = (loss_char) + (0.05 * loss_edge) + + loss.backward() + optims['optim'].step() + self.losses['loss'] = loss.numpy() + + def forward(self): + """Run forward pass; called by both functions and .""" + pass + + +def init_edvr_weight(net): + def reset_func(m): + if hasattr(m, 'weight') and (not isinstance( + m, (nn.BatchNorm, nn.BatchNorm2D))): + reset_parameters(m) + + net.apply(reset_func) diff --git a/requirements.txt b/requirements.txt index 91dc23a..a6129f3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,4 +8,4 @@ librosa==0.7.0 numba==0.48 easydict munch - +natsort -- GitLab