Unverified · Commit 9d9f69f0 authored by LielinJiang, committed by GitHub

fix sr docs and add div2k process script (#154)

Parent 95e5f4f3
@@ -53,8 +53,8 @@ dataset:
         keys: [image, image, image]
   test:
     name: SRDataset
-    gt_folder: data/DIV2K/val_set14/Set14
-    lq_folder: data/DIV2K/val_set14/Set14_bicLRx4
+    gt_folder: data/Set14/GTmod12
+    lq_folder: data/Set14/LRbicx4
     scale: 4
     preprocess:
       - name: LoadImageFromFile
......
@@ -18,8 +18,8 @@ model:
 dataset:
   train:
     name: SRDataset
-    gt_folder: data/DIV2K/DIV2K_train_HR_sub
-    lq_folder: data/DIV2K/DIV2K_train_LR_bicubic/X4_sub
+    gt_folder: data/DIV2K/DIV2K_train_HR
+    lq_folder: data/DIV2K/DIV2K_train_LR_bicubic/X4
     num_workers: 4
     batch_size: 16
     scale: 4
@@ -49,8 +49,8 @@ dataset:
         keys: [image, image]
   test:
     name: SRDataset
-    gt_folder: data/DIV2K/val_set14/Set14
-    lq_folder: data/DIV2K/val_set14/Set14_bicLRx4
+    gt_folder: data/Set14/GTmod12
+    lq_folder: data/Set14/LRbicx4
     scale: 4
     preprocess:
       - name: LoadImageFromFile
......
@@ -65,8 +65,8 @@ dataset:
         keys: [image, image]
   test:
     name: SRDataset
-    gt_folder: data/DIV2K/val_set14/Set14
-    lq_folder: data/DIV2K/val_set14/Set14_bicLRx4
+    gt_folder: data/Set14/GTmod12
+    lq_folder: data/Set14/LRbicx4
     scale: 4
     preprocess:
       - name: LoadImageFromFile
......
@@ -45,8 +45,8 @@ dataset:
         keys: [image, image]
   test:
     name: SRDataset
-    gt_folder: data/DIV2K/val_set14/Set14
-    lq_folder: data/DIV2K/val_set14/Set14_bicLRx4
+    gt_folder: data/Set14/GTmod12
+    lq_folder: data/Set14/LRbicx4
     scale: 4
     preprocess:
       - name: LoadImageFromFile
......
@@ -69,8 +69,8 @@ dataset:
         keys: [image]
   test:
     name: SRDataset
-    gt_folder: data/DIV2K/val_set14/Set14
-    lq_folder: data/DIV2K/val_set14/Set14_bicLRx4
+    gt_folder: data/Set14/GTmod12
+    lq_folder: data/Set14/LRbicx4
     scale: 4
     preprocess:
       - name: LoadImageFromFile
......
@@ -69,8 +69,8 @@ dataset:
         keys: [image]
   test:
     name: SRDataset
-    gt_folder: data/DIV2K/val_set14/Set14
-    lq_folder: data/DIV2K/val_set14/Set14_bicLRx4
+    gt_folder: data/Set14/GTmod12
+    lq_folder: data/Set14/LRbicx4
     scale: 4
     preprocess:
       - name: LoadImageFromFile
......
import argparse
import os
import os.path as osp
import re
import sys
from multiprocessing import Pool
from shutil import get_terminal_size
from time import time

import cv2
import numpy as np

from ppgan.datasets.base_dataset import scandir


class Timer:
"""A flexible Timer class."""
def __init__(self, start=True, print_tmpl=None):
self._is_running = False
self.print_tmpl = print_tmpl if print_tmpl else '{:.3f}'
if start:
self.start()
@property
def is_running(self):
"""bool: indicate whether the timer is running"""
return self._is_running
def __enter__(self):
self.start()
return self
def __exit__(self, type, value, traceback):
print(self.print_tmpl.format(self.since_last_check()))
self._is_running = False
def start(self):
"""Start the timer."""
if not self._is_running:
self._t_start = time()
self._is_running = True
self._t_last = time()
def since_start(self):
"""Total time since the timer is started.
Returns (float): Time in seconds.
"""
if not self._is_running:
raise ValueError('timer is not running')
self._t_last = time()
return self._t_last - self._t_start
def since_last_check(self):
"""Time since the last checking.
Either :func:`since_start` or :func:`since_last_check` is a checking
operation.
Returns (float): Time in seconds.
"""
if not self._is_running:
raise ValueError('timer is not running')
dur = time() - self._t_last
self._t_last = time()
return dur
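
# A hedged usage sketch for Timer (illustrative, not part of the original
# script): as a context manager it prints the time since the last check on
# exit, formatted with `print_tmpl`.
#
#     with Timer(print_tmpl='crop took {:.3f} s'):
#         process_one_image()  # hypothetical workload
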
class ProgressBar:
"""A progress bar which can print the progress."""
def __init__(self, task_num=0, bar_width=50, start=True, file=sys.stdout):
self.task_num = task_num
self.bar_width = bar_width
self.completed = 0
self.file = file
if start:
self.start()
@property
def terminal_width(self):
width, _ = get_terminal_size()
return width
def start(self):
if self.task_num > 0:
self.file.write(f'[{" " * self.bar_width}] 0/{self.task_num}, '
'elapsed: 0s, ETA:')
else:
self.file.write('completed: 0, elapsed: 0s')
self.file.flush()
self.timer = Timer()
def update(self, num_tasks=1):
assert num_tasks > 0
self.completed += num_tasks
elapsed = self.timer.since_start()
if elapsed > 0:
fps = self.completed / elapsed
else:
fps = float('inf')
if self.task_num > 0:
percentage = self.completed / float(self.task_num)
eta = int(elapsed * (1 - percentage) / percentage + 0.5)
msg = f'\r[{{}}] {self.completed}/{self.task_num}, ' \
f'{fps:.1f} task/s, elapsed: {int(elapsed + 0.5)}s, ' \
f'ETA: {eta:5}s'
bar_width = min(self.bar_width,
int(self.terminal_width - len(msg)) + 2,
int(self.terminal_width * 0.6))
bar_width = max(2, bar_width)
mark_width = int(bar_width * percentage)
bar_chars = '>' * mark_width + ' ' * (bar_width - mark_width)
self.file.write(msg.format(bar_chars))
else:
self.file.write(
f'completed: {self.completed}, elapsed: {int(elapsed + 0.5)}s,'
f' {fps:.1f} tasks/s')
self.file.flush()
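
# Illustrative ProgressBar usage (not in the original script): size the bar
# to the number of tasks, then call update() once per completed task.
#
#     bar = ProgressBar(task_num=len(img_list))
#     for _ in img_list:
#         bar.update()
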
def main_extract_subimages(args):
"""A multi-thread tool to crop large images to sub-images for faster IO.
It is used for DIV2K dataset.
args (dict): Configuration dict. It contains:
n_thread (int): Thread number.
compression_level (int): CV_IMWRITE_PNG_COMPRESSION from 0 to 9.
A higher value means a smaller size and longer compression time.
Use 0 for faster CPU decompression. Default: 3, same in cv2.
input_folder (str): Path to the input folder.
save_folder (str): Path to save folder.
crop_size (int): Crop size.
step (int): Step for overlapped sliding window.
thresh_size (int): Threshold size. Patches whose size is lower
than thresh_size will be dropped.
Usage:
For each folder, run this script.
Typically, there are four folders to be processed for DIV2K dataset.
DIV2K_train_HR
DIV2K_train_LR_bicubic/X2
DIV2K_train_LR_bicubic/X3
DIV2K_train_LR_bicubic/X4
        After processing, each sub-folder should contain the same number of
        sub-images.
        Remember to modify the opt configuration according to your settings.
"""
opt = {}
opt['n_thread'] = args.n_thread
opt['compression_level'] = args.compression_level
# HR images
opt['input_folder'] = osp.join(args.data_root, 'DIV2K_train_HR')
opt['save_folder'] = osp.join(args.data_root, 'DIV2K_train_HR_sub')
opt['crop_size'] = args.crop_size
opt['step'] = args.step
opt['thresh_size'] = args.thresh_size
extract_subimages(opt)
for scale in [2, 3, 4]:
opt['input_folder'] = osp.join(args.data_root,
f'DIV2K_train_LR_bicubic/X{scale}')
opt['save_folder'] = osp.join(args.data_root,
f'DIV2K_train_LR_bicubic/X{scale}_sub')
opt['crop_size'] = args.crop_size // scale
opt['step'] = args.step // scale
opt['thresh_size'] = args.thresh_size // scale
        extract_subimages(opt)


def extract_subimages(opt):
"""Crop images to subimages.
Args:
opt (dict): Configuration dict. It contains:
input_folder (str): Path to the input folder.
save_folder (str): Path to save folder.
n_thread (int): Thread number.
"""
input_folder = opt['input_folder']
save_folder = opt['save_folder']
if not osp.exists(save_folder):
os.makedirs(save_folder)
print(f'mkdir {save_folder} ...')
else:
print(f'Folder {save_folder} already exists. Exit.')
sys.exit(1)
img_list = list(scandir(input_folder))
img_list = [osp.join(input_folder, v) for v in img_list]
prog_bar = ProgressBar(len(img_list))
pool = Pool(opt['n_thread'])
for path in img_list:
pool.apply_async(worker,
args=(path, opt),
callback=lambda arg: prog_bar.update())
pool.close()
pool.join()
print('All processes done.')
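
# Illustrative call to extract_subimages (values mirror the script's own
# defaults; the paths are assumptions about your layout):
#     extract_subimages({
#         'input_folder': 'data/DIV2K/DIV2K_train_HR',
#         'save_folder': 'data/DIV2K/DIV2K_train_HR_sub',
#         'crop_size': 480, 'step': 240, 'thresh_size': 0,
#         'n_thread': 20, 'compression_level': 3,
#     })
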
def worker(path, opt):
"""Worker for each process.
Args:
path (str): Image path.
opt (dict): Configuration dict. It contains:
crop_size (int): Crop size.
step (int): Step for overlapped sliding window.
thresh_size (int): Threshold size. Patches whose size is smaller
than thresh_size will be dropped.
save_folder (str): Path to save folder.
compression_level (int): for cv2.IMWRITE_PNG_COMPRESSION.
Returns:
process_info (str): Process information displayed in progress bar.
"""
crop_size = opt['crop_size']
step = opt['step']
thresh_size = opt['thresh_size']
img_name, extension = osp.splitext(osp.basename(path))
# remove the x2, x3, x4 and x8 in the filename for DIV2K
img_name = re.sub('x[2348]', '', img_name)
img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
if img.ndim == 2 or img.ndim == 3:
h, w = img.shape[:2]
else:
raise ValueError(f'Image ndim should be 2 or 3, but got {img.ndim}')
h_space = np.arange(0, h - crop_size + 1, step)
if h - (h_space[-1] + crop_size) > thresh_size:
h_space = np.append(h_space, h - crop_size)
w_space = np.arange(0, w - crop_size + 1, step)
if w - (w_space[-1] + crop_size) > thresh_size:
w_space = np.append(w_space, w - crop_size)
index = 0
for x in h_space:
for y in w_space:
index += 1
cropped_img = img[x:x + crop_size, y:y + crop_size, ...]
cv2.imwrite(
osp.join(opt['save_folder'],
f'{img_name}_s{index:03d}{extension}'), cropped_img,
[cv2.IMWRITE_PNG_COMPRESSION, opt['compression_level']])
process_info = f'Processing {img_name} ...'
return process_info
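
# Worked example of the sliding-window arithmetic above, with hypothetical
# sizes: for h = 2040, crop_size = 480, step = 240, thresh_size = 0,
# np.arange(0, 2040 - 480 + 1, 240) yields starts [0, 240, ..., 1440]; the
# last patch ends at 1440 + 480 = 1920, and since 2040 - 1920 = 120 > 0,
# an extra start at 2040 - 480 = 1560 is appended, giving 8 row positions.
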
def parse_args():
    parser = argparse.ArgumentParser(
        description='Prepare DIV2K dataset',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--data-root', help='dataset root')
    # type=int ensures values given on the command line are parsed as
    # integers rather than strings (the crop arithmetic above needs ints)
    parser.add_argument('--crop-size',
                        nargs='?',
                        type=int,
                        default=480,
                        help='cropped size for the HR images')
    parser.add_argument('--step',
                        nargs='?',
                        type=int,
                        default=240,
                        help='step size for the HR images')
    parser.add_argument('--thresh-size',
                        nargs='?',
                        type=int,
                        default=0,
                        help='threshold size for the HR images')
    parser.add_argument('--compression-level',
                        nargs='?',
                        type=int,
                        default=3,
                        help='compression level when saving PNG images')
    parser.add_argument('--n-thread',
                        nargs='?',
                        type=int,
                        default=20,
                        help='number of worker processes')
    args = parser.parse_args()
    return args


if __name__ == '__main__':
args = parse_args()
# extract subimages
main_extract_subimages(args)
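
# Example invocation (paths are an assumption; adjust to your layout):
#     python data/process_div2k_data.py --data-root data/DIV2K \
#         --crop-size 480 --step 240 --thresh-size 0 --n-thread 20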
@@ -20,27 +20,36 @@
 | Classical SR Testing | Set5 | Set5 test dataset | [Google Drive](https://drive.google.com/drive/folders/1B3DJGQKB6eNdwuQIhdskA64qUuVKLZ9u) / [Baidu Drive](https://pan.baidu.com/s/1q_1ERCMqALH0xFwjLM0pTg#list/path=%2Fsharelink2016187762-785433459861126%2Fclassical_SR_datasets&parentPath=%2Fsharelink2016187762-785433459861126) |
 | Classical SR Testing | Set14 | Set14 test dataset | [Google Drive](https://drive.google.com/drive/folders/1B3DJGQKB6eNdwuQIhdskA64qUuVKLZ9u) / [Baidu Drive](https://pan.baidu.com/s/1q_1ERCMqALH0xFwjLM0pTg#list/path=%2Fsharelink2016187762-785433459861126%2Fclassical_SR_datasets&parentPath=%2Fsharelink2016187762-785433459861126) |
 
-The structure of DIV2K is as following:
-```
-DIV2K
-├── DIV2K_train_HR
-├── DIV2K_train_LR_bicubic
-|   ├── X2
-|   ├── X3
-|   └── X4
-├── DIV2K_valid_HR
-├── DIV2K_valid_LR_bicubic
-...
-```
-
-The structures of Set5 and Set14 are similar. Taking Set5 as an example, the structure is as following:
-```
-Set5
-├── GTmod12
-├── LRbicx2
-├── LRbicx3
-├── LRbicx4
-└── original
-```
+The structure of DIV2K, Set5 and Set14 is as follows:
+```
+PaddleGAN
+├── data
+|   ├── DIV2K
+|   |   ├── DIV2K_train_HR
+|   |   ├── DIV2K_train_LR_bicubic
+|   |   |   ├── X2
+|   |   |   ├── X3
+|   |   |   └── X4
+|   |   ├── DIV2K_valid_HR
+|   |   └── DIV2K_valid_LR_bicubic
+|   ├── Set5
+|   |   ├── GTmod12
+|   |   ├── LRbicx2
+|   |   ├── LRbicx3
+|   |   ├── LRbicx4
+|   |   └── original
+|   └── Set14
+|       ├── GTmod12
+|       ├── LRbicx2
+|       ├── LRbicx3
+|       ├── LRbicx4
+|       └── original
+...
+```
+
+Use the following command to preprocess the DIV2K dataset:
+```
+python data/process_div2k_data.py --data-root data/DIV2K
+```
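
A note on the script's output (inferred from `process_div2k_data.py` above rather than stated in this commit): the cropped patches are written to sibling `*_sub` folders, roughly:
```
DIV2K
├── DIV2K_train_HR_sub
├── DIV2K_train_LR_bicubic
|   ├── X2_sub
|   ├── X3_sub
|   └── X4_sub
...
```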
@@ -71,6 +80,7 @@ The metrics are PSNR / SSIM.
 | lesrcnn_x4 | 31.9476 / 0.8909 | 28.4110 / 0.7770 | 30.231 / 0.8326 |
 | esrgan_psnr_x4 | 32.5512 / 0.8991 | 28.8114 / 0.7871 | 30.7565 / 0.8449 |
 | esrgan_x4 | 28.7647 / 0.8187 | 25.0065 / 0.6762 | 26.9013 / 0.7542 |
+| drns_x4 | 32.6684 / 0.8999 | 28.9037 / 0.7885 | - |
<!-- ![](../../imgs/horse2zebra.png) -->
......
@@ -72,6 +72,23 @@ class Trainer:
     # save checkpoint (model.nets) \/
     """
     def __init__(self, cfg):
+        # base config
+        self.logger = logging.getLogger(__name__)
+        self.cfg = cfg
+        self.output_dir = cfg.output_dir
+        self.max_eval_steps = cfg.model.get('max_eval_steps', None)
+        self.local_rank = ParallelEnv().local_rank
+        self.log_interval = cfg.log_config.interval
+        self.visual_interval = cfg.log_config.visiual_interval
+        self.weight_interval = cfg.snapshot_config.interval
+        self.start_epoch = 1
+        self.current_epoch = 1
+        self.current_iter = 1
+        self.inner_iter = 1
+        self.batch_id = 0
+        self.global_steps = 0
+
         # build model
         self.model = build_model(cfg.model)
@@ -79,6 +96,21 @@ class Trainer:
         if ParallelEnv().nranks > 1:
             self.distributed_data_parallel()
+        # build metrics
+        self.metrics = None
+        validate_cfg = cfg.get('validate', None)
+        if validate_cfg and 'metrics' in validate_cfg:
+            self.metrics = self.model.setup_metrics(validate_cfg['metrics'])
+
+        self.enable_visualdl = cfg.get('enable_visualdl', False)
+        if self.enable_visualdl:
+            import visualdl
+            self.vdl_logger = visualdl.LogWriter(logdir=cfg.output_dir)
+
+        # evaluate only
+        if not cfg.is_train:
+            return
+
         # build train dataloader
         self.train_dataloader = build_dataloader(cfg.dataset.train)
         self.iters_per_epoch = len(self.train_dataloader)
@@ -93,21 +125,6 @@ class Trainer:
         self.optimizers = self.model.setup_optimizers(self.lr_schedulers,
                                                       cfg.optimizer)
-        # build metrics
-        self.metrics = None
-        validate_cfg = cfg.get('validate', None)
-        if validate_cfg and 'metrics' in validate_cfg:
-            self.metrics = self.model.setup_metrics(validate_cfg['metrics'])
-
-        self.logger = logging.getLogger(__name__)
-        self.enable_visualdl = cfg.get('enable_visualdl', False)
-        if self.enable_visualdl:
-            import visualdl
-            self.vdl_logger = visualdl.LogWriter(logdir=cfg.output_dir)
-
-        # base config
-        self.output_dir = cfg.output_dir
-        self.max_eval_steps = cfg.model.get('max_eval_steps', None)
-
         self.epochs = cfg.get('epochs', None)
         if self.epochs:
             self.total_iters = self.epochs * self.iters_per_epoch
@@ -116,26 +133,12 @@ class Trainer:
             self.by_epoch = False
             self.total_iters = cfg.total_iters
 
-        self.start_epoch = 1
-        self.current_epoch = 1
-        self.current_iter = 1
-        self.inner_iter = 1
-        self.batch_id = 0
-        self.global_steps = 0
-
-        self.weight_interval = cfg.snapshot_config.interval
         if self.by_epoch:
             self.weight_interval *= self.iters_per_epoch
-        self.log_interval = cfg.log_config.interval
-        self.visual_interval = cfg.log_config.visiual_interval
-        if self.by_epoch:
-            self.weight_interval *= self.iters_per_epoch
 
         self.validate_interval = -1
         if cfg.get('validate', None) is not None:
             self.validate_interval = cfg.validate.get('interval', -1)
-        self.cfg = cfg
-        self.local_rank = ParallelEnv().local_rank
 
         self.time_count = {}
         self.best_metric = {}
......