# code was heavily based on https://github.com/swz30/MPRNet
# Users should be careful about adopting these functions in any commercial matters.
# https://github.com/swz30/MPRNet/blob/main/LICENSE.md

import os
import random

import numpy as np
from PIL import Image
from paddle.io import Dataset

from .builder import DATASETS
from paddle.vision.transforms.functional import (
    adjust_brightness,
    adjust_saturation,
    center_crop,
    hflip,
    pad,
    rotate,
    to_tensor,
    vflip,
)

# Plain suffix match; case variants are listed explicitly.
_IMG_EXTENSIONS = ('jpeg', 'JPEG', 'jpg', 'png', 'JPG', 'PNG', 'gif')


def is_image_file(filename):
    """Return True if *filename* ends with one of the supported extensions."""
    return filename.endswith(_IMG_EXTENSIONS)


def _list_images(directory):
    """Return the sorted full paths of the image files inside *directory*."""
    return [
        os.path.join(directory, name)
        for name in sorted(os.listdir(directory)) if is_image_file(name)
    ]


def _stem(path):
    """Return the file name of *path* without its directory or extension."""
    return os.path.splitext(os.path.basename(path))[0]


def _rotate90(img, times):
    """Rotate *img* by ``90 * times`` degrees without losing any pixels.

    ``expand=True`` lets the canvas swap width/height; with the default
    ``expand=False`` a 90/270 degree rotation of a non-square image would be
    cropped to the original canvas and the uncovered corners zero-filled.
    """
    return rotate(img, 90 * times, expand=True)


# Paired geometric augmentations keyed by the value drawn from
# random.randint(0, 8); draws not listed here (0 and 8) leave the pair as is.
_GEOMETRIC_AUGS = {
    1: vflip,
    2: hflip,
    3: lambda img: _rotate90(img, 1),
    4: lambda img: _rotate90(img, 2),
    5: lambda img: _rotate90(img, 3),
    6: lambda img: _rotate90(vflip(img), 1),
    7: lambda img: _rotate90(hflip(img), 1),
}


@DATASETS.register()
class NAFNetTrain(Dataset):
    """Paired training dataset that yields aligned random patches.

    ``rgb_dir`` must contain an ``input`` and a ``target`` sub-directory.
    The image files in each are sorted by name and paired by position.

    Args:
        rgb_dir (str): root directory holding ``input/`` and ``target/``.
        img_options (dict): must provide ``'patch_size'``, the side length
            of the square crop returned by ``__getitem__``.
    """

    def __init__(self, rgb_dir, img_options=None):
        super().__init__()
        self.inp_filenames = _list_images(os.path.join(rgb_dir, 'input'))
        self.tar_filenames = _list_images(os.path.join(rgb_dir, 'target'))
        self.img_options = img_options
        self.sizex = len(self.tar_filenames)  # length follows the targets
        self.ps = self.img_options['patch_size']

    def __len__(self):
        return self.sizex

    def __getitem__(self, index):
        """Return ``(target_patch, input_patch, filename)`` for *index*.

        Both images receive the same padding, color jitter, flip/rotation
        and crop, so the pair stays pixel-aligned. ``filename`` is the target
        file's name without directory or extension.
        """
        index_ = index % self.sizex
        ps = self.ps

        inp_path = self.inp_filenames[index_]
        tar_path = self.tar_filenames[index_]
        inp_img = Image.open(inp_path)
        tar_img = Image.open(tar_path)

        # Reflect-pad (right/bottom edges only) when the image is smaller
        # than the patch, so the random crop below always fits.
        w, h = tar_img.size
        padw = max(ps - w, 0)
        padh = max(ps - h, 0)
        if padw or padh:
            padding = (0, 0, padw, padh)  # (left, top, right, bottom)
            inp_img = pad(inp_img, padding, padding_mode='reflect')
            tar_img = pad(tar_img, padding, padding_mode='reflect')

        # NOTE(review): a brightness factor of 1 returns the image unchanged,
        # so this branch is effectively a no-op.
        if random.randint(0, 2) == 1:
            inp_img = adjust_brightness(inp_img, 1)
            tar_img = adjust_brightness(tar_img, 1)

        # With probability 1/3, scale saturation by a factor in (0.8, 1.2].
        if random.randint(0, 2) == 1:
            sat_factor = 1 + (0.2 - 0.4 * np.random.rand())
            inp_img = adjust_saturation(inp_img, sat_factor)
            tar_img = adjust_saturation(tar_img, sat_factor)

        # Paired flip/rotation; see _GEOMETRIC_AUGS for the mapping.
        transform = _GEOMETRIC_AUGS.get(random.randint(0, 8))
        if transform is not None:
            inp_img = transform(inp_img)
            tar_img = transform(tar_img)

        inp_img = to_tensor(inp_img)
        tar_img = to_tensor(tar_img)

        # Random ps x ps crop at the same location in both (C, H, W) tensors.
        hh, ww = tar_img.shape[1], tar_img.shape[2]
        rr = random.randint(0, hh - ps)
        cc = random.randint(0, ww - ps)
        inp_img = inp_img[:, rr:rr + ps, cc:cc + ps]
        tar_img = tar_img[:, rr:rr + ps, cc:cc + ps]

        return tar_img, inp_img, _stem(tar_path)


@DATASETS.register()
class NAFNetVal(Dataset):
    """Paired validation dataset, optionally center-cropped.

    ``rgb_dir`` must contain an ``input`` and a ``target`` sub-directory.
    The image files in each are sorted by name and paired by position.

    Args:
        rgb_dir (str): root directory holding ``input/`` and ``target/``.
        img_options (dict | None): may provide ``'patch_size'``. When it is
            missing (or ``img_options`` is None), full images are returned.
        rgb_dir2: accepted but not used by this class.
    """

    def __init__(self, rgb_dir, img_options=None, rgb_dir2=None):
        super().__init__()
        self.inp_filenames = _list_images(os.path.join(rgb_dir, 'input'))
        self.tar_filenames = _list_images(os.path.join(rgb_dir, 'target'))
        self.img_options = img_options
        self.sizex = len(self.tar_filenames)  # length follows the targets
        # img_options defaults to None, so do not index it unconditionally.
        self.ps = (img_options or {}).get('patch_size')

    def __len__(self):
        return self.sizex

    def __getitem__(self, index):
        """Return ``(target, input, filename)`` for *index*.

        When a patch size is configured, both images are center-cropped to
        ``(ps, ps)`` before conversion to tensors.
        """
        index_ = index % self.sizex
        ps = self.ps

        inp_path = self.inp_filenames[index_]
        tar_path = self.tar_filenames[index_]
        inp_img = Image.open(inp_path)
        tar_img = Image.open(tar_path)

        # Validate on center crop
        if ps is not None:
            inp_img = center_crop(inp_img, (ps, ps))
            tar_img = center_crop(tar_img, (ps, ps))

        inp_img = to_tensor(inp_img)
        tar_img = to_tensor(tar_img)

        return tar_img, inp_img, _stem(tar_path)


@DATASETS.register()
class NAFNetTest(Dataset):
    """Unpaired test dataset over every image file directly in ``inp_dir``.

    Args:
        inp_dir (str): directory of input images (sorted by name).
        img_options: stored on the instance; not used by this class.
    """

    def __init__(self, inp_dir, img_options):
        super().__init__()
        self.inp_filenames = _list_images(inp_dir)
        self.inp_size = len(self.inp_filenames)
        self.img_options = img_options

    def __len__(self):
        return self.inp_size

    def __getitem__(self, index):
        """Return ``(input_tensor, filename)`` for *index*."""
        path_inp = self.inp_filenames[index]
        inp = to_tensor(Image.open(path_inp))
        return inp, _stem(path_inp)