# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import cv2
import random
import numpy as np
import paddle.vision.transforms as transform

from pathlib import Path
from paddle.io import Dataset
from .builder import DATASETS


def scandir(dir_path, suffix=None, recursive=False):
    """Scan a directory to find the interested files.

    Args:
        dir_path (str | Path): Directory to scan.
        suffix (str | tuple[str], optional): Only paths ending with this
            suffix (or any suffix in the tuple) are yielded. ``None`` yields
            every file. Default: None.
        recursive (bool): Whether to descend into sub-directories.
            Default: False.

    Returns:
        generator: Yields file paths relative to ``dir_path``. Files whose
            name starts with ``'.'`` are skipped.

    Raises:
        TypeError: If ``dir_path`` is not a str/Path, or ``suffix`` is
            neither None, a str nor a tuple.
    """
    if isinstance(dir_path, (str, Path)):
        dir_path = str(dir_path)
    else:
        raise TypeError('"dir_path" must be a string or Path object')

    if (suffix is not None) and not isinstance(suffix, (str, tuple)):
        raise TypeError('"suffix" must be a string or tuple of strings')

    root = dir_path

    def _scandir(dir_path, suffix, recursive):
        # The context manager releases the directory handle even when the
        # consumer stops iterating early.
        with os.scandir(dir_path) as entries:
            for entry in entries:
                if not entry.name.startswith('.') and entry.is_file():
                    rel_path = os.path.relpath(entry.path, root)
                    if suffix is None or rel_path.endswith(suffix):
                        yield rel_path
                elif recursive and entry.is_dir():
                    # Only real directories are descended into; a hidden
                    # *file* also reaches this branch and must not be passed
                    # to os.scandir (it would raise NotADirectoryError).
                    yield from _scandir(entry.path,
                                        suffix=suffix,
                                        recursive=recursive)

    return _scandir(dir_path, suffix=suffix, recursive=recursive)


def paired_paths_from_folder(folders, keys, filename_tmpl):
    """Generate paired paths from folders.

    Args:
        folders (list[str]): Two folders, ``[input_folder, gt_folder]``.
        keys (list[str]): Two keys, ``[input_key, gt_key]``, used to build
            the ``'<key>_path'`` entries of each returned dict.
        filename_tmpl (str): ``str.format`` template applied to the GT
            basename (without extension) to obtain the input basename.

    Returns:
        list[dict]: One dict per GT file with the keys
            ``'<input_key>_path'`` and ``'<gt_key>_path'``.
    """
    assert len(folders) == 2, (
        'The len of folders should be 2 with [input_folder, gt_folder]. '
        f'But got {len(folders)}')
    assert len(keys) == 2, (
        'The len of keys should be 2 with [input_key, gt_key]. '
        f'But got {len(keys)}')
    input_folder, gt_folder = folders
    input_key, gt_key = keys

    input_paths = list(scandir(input_folder))
    gt_paths = list(scandir(gt_folder))
    assert len(input_paths) == len(gt_paths), (
        f'{input_key} and {gt_key} datasets have different number of images: '
        f'{len(input_paths)}, {len(gt_paths)}.')

    # Set lookup keeps the pairing loop linear instead of quadratic.
    input_names = set(input_paths)

    paths = []
    for gt_path in gt_paths:
        basename, ext = os.path.splitext(os.path.basename(gt_path))
        input_name = f'{filename_tmpl.format(basename)}{ext}'
        input_path = os.path.join(input_folder, input_name)
        assert input_name in input_names, (f'{input_name} is not in '
                                           f'{input_key}_paths.')
        gt_path = os.path.join(gt_folder, gt_path)
        paths.append({
            f'{input_key}_path': input_path,
            f'{gt_key}_path': gt_path,
        })
    return paths


def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path):
    """Paired random crop.

    It crops lists of lq and gt images with corresponding locations.

    Args:
        img_gts (list[ndarray] | ndarray): GT images. Note that all images
            should have the same shape. If the input is an ndarray, it will
            be transformed to a list containing itself.
        img_lqs (list[ndarray] | ndarray): LQ images. Note that all images
            should have the same shape. If the input is an ndarray, it will
            be transformed to a list containing itself.
        gt_patch_size (int): GT patch size.
        scale (int): Scale factor.
        gt_path (str): Path to ground-truth.

    Returns:
        list[ndarray] | ndarray: GT images and LQ images. If returned results
            only have one element, just return ndarray.

    Raises:
        ValueError: If the GT size is not ``scale`` times the LQ size, or if
            the LQ image is smaller than the LQ patch size.
    """

    if not isinstance(img_gts, list):
        img_gts = [img_gts]
    if not isinstance(img_lqs, list):
        img_lqs = [img_lqs]

    h_lq, w_lq, _ = img_lqs[0].shape
    h_gt, w_gt, _ = img_gts[0].shape
    lq_patch_size = gt_patch_size // scale

    if h_gt != h_lq * scale or w_gt != w_lq * scale:
        # The two fragments are concatenated into ONE message; a comma here
        # would turn the exception argument into a tuple.
        raise ValueError(
            f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x '
            f'multiplication of LQ ({h_lq}, {w_lq}).')
    if h_lq < lq_patch_size or w_lq < lq_patch_size:
        raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
                         f'({lq_patch_size}, {lq_patch_size}). '
                         f'Please remove {gt_path}.')

    # randomly choose top and left coordinates for lq patch
    top = random.randint(0, h_lq - lq_patch_size)
    left = random.randint(0, w_lq - lq_patch_size)

    # crop lq patch
    img_lqs = [
        v[top:top + lq_patch_size, left:left + lq_patch_size, ...]
        for v in img_lqs
    ]

    # crop corresponding gt patch
    top_gt, left_gt = int(top * scale), int(left * scale)
    img_gts = [
        v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...]
        for v in img_gts
    ]
    if len(img_gts) == 1:
        img_gts = img_gts[0]
    if len(img_lqs) == 1:
        img_lqs = img_lqs[0]
    return img_gts, img_lqs


def augment(imgs, hflip=True, rotation=True, flows=None):
    """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees).

    Each of the three operations (horizontal flip, vertical flip, transpose)
    is enabled independently with probability 0.5; the same decision is
    applied to every image and flow passed in. Flips are done in place via
    ``cv2.flip(src, code, dst=src)``, so the input arrays are modified.

    Args:
        imgs (list[ndarray] | ndarray): Images to augment.
        hflip (bool): Allow horizontal flip. Default: True.
        rotation (bool): Allow vertical flip and transpose. Default: True.
        flows (list[ndarray] | ndarray, optional): Flow fields augmented
            consistently with the images. Default: None.

    Returns:
        list[ndarray] | ndarray: Augmented images (a single ndarray when one
            image was given). When ``flows`` is not None, a tuple
            ``(imgs, flows)`` is returned instead.
    """
    # Draw order is kept fixed so seeded runs stay reproducible.
    hflip = hflip and random.random() < 0.5
    vflip = rotation and random.random() < 0.5
    rot90 = rotation and random.random() < 0.5

    def _augment(img):
        if hflip:
            cv2.flip(img, 1, img)
        if vflip:
            cv2.flip(img, 0, img)
        if rot90:
            img = img.transpose(1, 0, 2)
        return img

    def _augment_flow(flow):
        if hflip:
            cv2.flip(flow, 1, flow)
            # mirrored x axis -> negate the first flow channel
            flow[:, :, 0] *= -1
        if vflip:
            cv2.flip(flow, 0, flow)
            # mirrored y axis -> negate the second flow channel
            flow[:, :, 1] *= -1
        if rot90:
            flow = flow.transpose(1, 0, 2)
            # axes were swapped, so swap the two flow channels as well
            flow = flow[:, :, [1, 0]]
        return flow

    if not isinstance(imgs, list):
        imgs = [imgs]
    imgs = [_augment(img) for img in imgs]
    if len(imgs) == 1:
        imgs = imgs[0]

    if flows is not None:
        if not isinstance(flows, list):
            flows = [flows]
        flows = [_augment_flow(flow) for flow in flows]
        if len(flows) == 1:
            flows = flows[0]
        return imgs, flows
    else:
        return imgs


def _imread_float32(path):
    """Read an image with OpenCV and scale it to float32 in [0, 1].

    Raises:
        IOError: If OpenCV cannot read ``path`` (``cv2.imread`` returns None
            instead of raising).
    """
    img = cv2.imread(path)
    if img is None:
        raise IOError(f'Failed to read image: {path}')
    return img.astype(np.float32) / 255.


@DATASETS.register()
class SRImageDataset(Dataset):
    """Paired image dataset for image restoration.

    Config keys read by this class: ``io_backend`` (dict with ``type``),
    ``dataroot_gt``, ``dataroot_lq``, ``scale``, ``phase`` and, optionally,
    ``filename_tmpl`` and ``meta_info_file``. When ``phase == 'train'`` the
    keys ``gt_size``, ``use_flip`` and ``use_rot`` are read as well.
    """

    def __init__(self, cfg):
        """Collect the paired lq/gt file paths described by ``cfg``.

        Raises:
            NotImplementedError: If ``cfg`` selects the lmdb backend or a
                meta info file; neither is supported yet.
        """
        super(SRImageDataset, self).__init__()
        self.cfg = cfg

        self.file_client = None
        self.io_backend_opt = cfg['io_backend']

        self.gt_folder, self.lq_folder = cfg['dataroot_gt'], cfg['dataroot_lq']
        if 'filename_tmpl' in cfg:
            self.filename_tmpl = cfg['filename_tmpl']
        else:
            self.filename_tmpl = '{}'

        # Fail fast for unsupported modes: silently continuing would leave
        # ``self.paths`` undefined and break later in __len__/__getitem__.
        if self.io_backend_opt['type'] == 'lmdb':
            #TODO: LielinJiang support lmdb to accelerate io
            raise NotImplementedError(
                'SRImageDataset does not support the lmdb io backend yet.')
        elif 'meta_info_file' in self.cfg and self.cfg[
                'meta_info_file'] is not None:
            #TODO: LielinJiang support meta info file
            raise NotImplementedError(
                'SRImageDataset does not support meta_info_file yet.')
        else:
            self.paths = paired_paths_from_folder(
                [self.lq_folder, self.gt_folder], ['lq', 'gt'],
                self.filename_tmpl)

    def __getitem__(self, index):
        """Load, optionally augment, and return the ``index``-th pair.

        Returns:
            dict: ``'lq'`` and ``'gt'`` images after ``transform.Permute``,
                plus their source paths under ``'lq_path'`` and ``'gt_path'``.
        """
        scale = self.cfg['scale']

        # Load gt and lq images. Dimension order: HWC; channel order: BGR;
        # image range: [0, 1], float32.
        gt_path = self.paths[index]['gt_path']
        lq_path = self.paths[index]['lq_path']
        img_gt = _imread_float32(gt_path)
        img_lq = _imread_float32(lq_path)

        # augmentation for training
        if self.cfg['phase'] == 'train':
            gt_size = self.cfg['gt_size']
            # random crop
            img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale,
                                                gt_path)
            # flip, rotation
            img_gt, img_lq = augment([img_gt, img_lq], self.cfg['use_flip'],
                                     self.cfg['use_rot'])

        # TODO: color space transform
        # BGR to RGB, HWC to CHW, numpy to tensor
        # NOTE(review): the line above describes what Permute is expected to
        # do; confirm against the installed paddle version's defaults.
        permute = transform.Permute()
        img_gt = permute(img_gt)
        img_lq = permute(img_lq)
        return {
            'lq': img_lq,
            'gt': img_gt,
            'lq_path': lq_path,
            'gt_path': gt_path
        }

    def __len__(self):
        """Return the number of lq/gt pairs."""
        return len(self.paths)