# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import mmcv import os import cv2 import random import numpy as np import paddle.vision.transforms as transform from pathlib import Path from paddle.io import Dataset from .builder import DATASETS def scandir(dir_path, suffix=None, recursive=False): """Scan a directory to find the interested files. """ if isinstance(dir_path, (str, Path)): dir_path = str(dir_path) else: raise TypeError('"dir_path" must be a string or Path object') if (suffix is not None) and not isinstance(suffix, (str, tuple)): raise TypeError('"suffix" must be a string or tuple of strings') root = dir_path def _scandir(dir_path, suffix, recursive): for entry in os.scandir(dir_path): if not entry.name.startswith('.') and entry.is_file(): rel_path = os.path.relpath(entry.path, root) if suffix is None: yield rel_path elif rel_path.endswith(suffix): yield rel_path else: if recursive: yield from _scandir(entry.path, suffix=suffix, recursive=recursive) else: continue return _scandir(dir_path, suffix=suffix, recursive=recursive) def paired_paths_from_folder(folders, keys, filename_tmpl): """Generate paired paths from folders. """ assert len(folders) == 2, ( 'The len of folders should be 2 with [input_folder, gt_folder]. ' f'But got {len(folders)}') assert len(keys) == 2, ( 'The len of keys should be 2 with [input_key, gt_key]. ' f'But got {len(keys)}') input_folder, gt_folder = folders input_key, gt_key = keys input_paths = list(scandir(input_folder)) gt_paths = list(scandir(gt_folder)) assert len(input_paths) == len(gt_paths), ( f'{input_key} and {gt_key} datasets have different number of images: ' f'{len(input_paths)}, {len(gt_paths)}.') paths = [] for gt_path in gt_paths: basename, ext = os.path.splitext(os.path.basename(gt_path)) input_name = f'{filename_tmpl.format(basename)}{ext}' input_path = os.path.join(input_folder, input_name) assert input_name in input_paths, (f'{input_name} is not in ' f'{input_key}_paths.') gt_path = os.path.join(gt_folder, gt_path) paths.append( dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)])) return paths def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path): """Paired random crop. It crops lists of lq and gt images with corresponding locations. Args: img_gts (list[ndarray] | ndarray): GT images. Note that all images should have the same shape. If the input is an ndarray, it will be transformed to a list containing itself. img_lqs (list[ndarray] | ndarray): LQ images. Note that all images should have the same shape. If the input is an ndarray, it will be transformed to a list containing itself. gt_patch_size (int): GT patch size. scale (int): Scale factor. gt_path (str): Path to ground-truth. Returns: list[ndarray] | ndarray: GT images and LQ images. If returned results only have one element, just return ndarray. """ if not isinstance(img_gts, list): img_gts = [img_gts] if not isinstance(img_lqs, list): img_lqs = [img_lqs] h_lq, w_lq, _ = img_lqs[0].shape h_gt, w_gt, _ = img_gts[0].shape lq_patch_size = gt_patch_size // scale if h_gt != h_lq * scale or w_gt != w_lq * scale: raise ValueError( f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x ', f'multiplication of LQ ({h_lq}, {w_lq}).') if h_lq < lq_patch_size or w_lq < lq_patch_size: raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size ' f'({lq_patch_size}, {lq_patch_size}). ' f'Please remove {gt_path}.') # randomly choose top and left coordinates for lq patch top = random.randint(0, h_lq - lq_patch_size) left = random.randint(0, w_lq - lq_patch_size) # crop lq patch img_lqs = [ v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs ] # crop corresponding gt patch top_gt, left_gt = int(top * scale), int(left * scale) img_gts = [ v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts ] if len(img_gts) == 1: img_gts = img_gts[0] if len(img_lqs) == 1: img_lqs = img_lqs[0] return img_gts, img_lqs def augment(imgs, hflip=True, rotation=True, flows=None): """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees). """ hflip = hflip and random.random() < 0.5 vflip = rotation and random.random() < 0.5 rot90 = rotation and random.random() < 0.5 def _augment(img): if hflip: cv2.flip(img, 1, img) if vflip: cv2.flip(img, 0, img) if rot90: img = img.transpose(1, 0, 2) return img def _augment_flow(flow): if hflip: cv2.flip(flow, 1, flow) flow[:, :, 0] *= -1 if vflip: cv2.flip(flow, 0, flow) flow[:, :, 1] *= -1 if rot90: flow = flow.transpose(1, 0, 2) flow = flow[:, :, [1, 0]] return flow if not isinstance(imgs, list): imgs = [imgs] imgs = [_augment(img) for img in imgs] if len(imgs) == 1: imgs = imgs[0] if flows is not None: if not isinstance(flows, list): flows = [flows] flows = [_augment_flow(flow) for flow in flows] if len(flows) == 1: flows = flows[0] return imgs, flows else: return imgs @DATASETS.register() class SRImageDataset(Dataset): """Paired image dataset for image restoration.""" def __init__(self, cfg): super(SRImageDataset, self).__init__() self.cfg = cfg self.file_client = None self.io_backend_opt = cfg['io_backend'] self.gt_folder, self.lq_folder = cfg['dataroot_gt'], cfg['dataroot_lq'] if 'filename_tmpl' in cfg: self.filename_tmpl = cfg['filename_tmpl'] else: self.filename_tmpl = '{}' if self.io_backend_opt['type'] == 'lmdb': #TODO: LielinJiang support lmdb to accelerate io pass elif 'meta_info_file' in self.cfg and self.cfg[ 'meta_info_file'] is not None: #TODO: LielinJiang support lmdb to accelerate io pass else: self.paths = paired_paths_from_folder( [self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl) def __getitem__(self, index): scale = self.cfg['scale'] # Load gt and lq images. Dimension order: HWC; channel order: BGR; # image range: [0, 1], float32. gt_path = self.paths[index]['gt_path'] lq_path = self.paths[index]['lq_path'] img_gt = cv2.imread(gt_path).astype(np.float32) / 255. img_lq = cv2.imread(lq_path).astype(np.float32) / 255. # augmentation for training if self.cfg['phase'] == 'train': gt_size = self.cfg['gt_size'] # random crop img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path) # flip, rotation img_gt, img_lq = augment([img_gt, img_lq], self.cfg['use_flip'], self.cfg['use_rot']) # TODO: color space transform # BGR to RGB, HWC to CHW, numpy to tensor permute = transform.Permute() img_gt = permute(img_gt) img_lq = permute(img_lq) return { 'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path } def __len__(self): return len(self.paths)