From 424ab9eb5db031d178f6048b2f13aa0472ecfaab Mon Sep 17 00:00:00 2001 From: wangna11BD <79366697+wangna11BD@users.noreply.github.com> Date: Wed, 15 Dec 2021 20:07:39 +0800 Subject: [PATCH] fix reds dataset (#527) * fix reds dataset * fix CI --- configs/basicvsr++_reds.yaml | 70 ++-- configs/basicvsr++_vimeo90k_BD.yaml | 4 +- configs/basicvsr_reds.yaml | 75 +++-- configs/edvr_l_blur_w_tsa.yaml | 92 +++-- configs/edvr_l_blur_wo_tsa.yaml | 92 +++-- configs/edvr_l_w_tsa.yaml | 88 +++-- configs/edvr_l_wo_tsa.yaml | 88 +++-- configs/edvr_m_w_tsa.yaml | 88 +++-- configs/edvr_m_wo_tsa.yaml | 88 +++-- configs/esrgan_psnr_x4_div2k.yaml | 4 +- configs/esrgan_x4_div2k.yaml | 4 +- configs/iconvsr_reds.yaml | 72 ++-- configs/lesrcnn_psnr_x4_div2k.yaml | 4 +- configs/msvsr_l_reds.yaml | 71 ++-- configs/msvsr_reds.yaml | 75 +++-- configs/msvsr_vimeo90k_BD.yaml | 5 +- configs/realsr_bicubic_noise_x4_df2k.yaml | 4 +- configs/realsr_kernel_noise_x4_dped.yaml | 4 +- ppgan/datasets/__init__.py | 4 +- ppgan/datasets/edvr_dataset.py | 313 ------------------ ppgan/datasets/preprocess/__init__.py | 2 +- ppgan/datasets/preprocess/io.py | 148 +++++++++ ppgan/datasets/sr_reds_multiple_gt_dataset.py | 237 ------------- ppgan/datasets/vsr_reds_dataset.py | 97 ++++++ .../datasets/vsr_reds_multiple_gt_dataset.py | 92 +++++ ppgan/models/edvr_model.py | 2 +- 26 files changed, 968 insertions(+), 855 deletions(-) delete mode 100644 ppgan/datasets/edvr_dataset.py delete mode 100644 ppgan/datasets/sr_reds_multiple_gt_dataset.py create mode 100644 ppgan/datasets/vsr_reds_dataset.py create mode 100644 ppgan/datasets/vsr_reds_multiple_gt_dataset.py diff --git a/configs/basicvsr++_reds.yaml b/configs/basicvsr++_reds.yaml index 61a8e50..f10e8e0 100644 --- a/configs/basicvsr++_reds.yaml +++ b/configs/basicvsr++_reds.yaml @@ -27,33 +27,61 @@ dataset: num_workers: 4 batch_size: 2 #4 gpus dataset: - name: SRREDSMultipleGTDataset - mode: train + name: VSRREDSMultipleGTDataset lq_folder: 
data/REDS/train_sharp_bicubic/X4 gt_folder: data/REDS/train_sharp/X4 - crop_size: 256 - interval_list: [1] - random_reverse: False - number_frames: 30 - use_flip: True - use_rot: True - scale: 4 - val_partition: REDS4 + ann_file: data/REDS/meta_info_REDS_GT.txt + num_frames: 30 + preprocess: + - name: GetNeighboringFramesIdx + interval_list: [1] + - name: ReadImageSequence + key: lq + - name: ReadImageSequence + key: gt + - name: Transforms + input_keys: [lq, gt] + pipeline: + - name: SRPairedRandomCrop + gt_patch_size: 256 + scale: 4 + keys: [image, image] + - name: PairedRandomHorizontalFlip + keys: [image, image] + - name: PairedRandomVerticalFlip + keys: [image, image] + - name: PairedRandomTransposeHW + keys: [image, image] + - name: TransposeSequence + keys: [image, image] + - name: NormalizeSequence + mean: [0., 0., 0.] + std: [255., 255., 255.] + keys: [image, image] test: - name: SRREDSMultipleGTDataset - mode: test + name: VSRREDSMultipleGTDataset lq_folder: data/REDS/REDS4_test_sharp_bicubic/X4 gt_folder: data/REDS/REDS4_test_sharp/X4 - interval_list: [1] - random_reverse: False - number_frames: 100 - use_flip: False - use_rot: False - scale: 4 - val_partition: REDS4 - num_workers: 0 - batch_size: 1 + ann_file: data/REDS/meta_info_REDS_GT.txt + num_frames: 100 + test_mode: True + preprocess: + - name: GetNeighboringFramesIdx + interval_list: [1] + - name: ReadImageSequence + key: lq + - name: ReadImageSequence + key: gt + - name: Transforms + input_keys: [lq, gt] + pipeline: + - name: TransposeSequence + keys: [image, image] + - name: NormalizeSequence + mean: [0., 0., 0.] + std: [255., 255., 255.] 
+ keys: [image, image] lr_scheduler: name: CosineAnnealingRestartLR diff --git a/configs/basicvsr++_vimeo90k_BD.yaml b/configs/basicvsr++_vimeo90k_BD.yaml index ec4294d..95c1634 100644 --- a/configs/basicvsr++_vimeo90k_BD.yaml +++ b/configs/basicvsr++_vimeo90k_BD.yaml @@ -53,7 +53,7 @@ dataset: keys: [image, image] - name: MirrorVideoSequence - name: NormalizeSequence - mean: [0., .0, 0.] + mean: [0., 0., 0.] std: [255., 255., 255.] keys: [image, image] @@ -80,7 +80,7 @@ dataset: - name: TransposeSequence keys: [image, image] - name: NormalizeSequence - mean: [0., .0, 0.] + mean: [0., 0., 0.] std: [255., 255., 255.] keys: [image, image] diff --git a/configs/basicvsr_reds.yaml b/configs/basicvsr_reds.yaml index 172764f..f803a3c 100644 --- a/configs/basicvsr_reds.yaml +++ b/configs/basicvsr_reds.yaml @@ -24,37 +24,64 @@ dataset: name: RepeatDataset times: 1000 num_workers: 4 - batch_size: 2 #4 GPUs + batch_size: 2 #4 gpus dataset: - name: SRREDSMultipleGTDataset - mode: train + name: VSRREDSMultipleGTDataset lq_folder: data/REDS/train_sharp_bicubic/X4 gt_folder: data/REDS/train_sharp/X4 - crop_size: 256 - interval_list: [1] - random_reverse: False - number_frames: 15 - use_flip: True - use_rot: True - scale: 4 - val_partition: REDS4 - num_clips: 270 + ann_file: data/REDS/meta_info_REDS_GT.txt + num_frames: 15 + preprocess: + - name: GetNeighboringFramesIdx + interval_list: [1] + - name: ReadImageSequence + key: lq + - name: ReadImageSequence + key: gt + - name: Transforms + input_keys: [lq, gt] + pipeline: + - name: SRPairedRandomCrop + gt_patch_size: 256 + scale: 4 + keys: [image, image] + - name: PairedRandomHorizontalFlip + keys: [image, image] + - name: PairedRandomVerticalFlip + keys: [image, image] + - name: PairedRandomTransposeHW + keys: [image, image] + - name: TransposeSequence + keys: [image, image] + - name: NormalizeSequence + mean: [0., 0., 0.] + std: [255., 255., 255.] 
+ keys: [image, image] test: - name: SRREDSMultipleGTDataset - mode: test + name: VSRREDSMultipleGTDataset lq_folder: data/REDS/REDS4_test_sharp_bicubic/X4 gt_folder: data/REDS/REDS4_test_sharp/X4 - interval_list: [1] - random_reverse: False - number_frames: 100 - use_flip: False - use_rot: False - scale: 4 - val_partition: REDS4 - num_workers: 0 - batch_size: 1 - num_clips: 270 + ann_file: data/REDS/meta_info_REDS_GT.txt + num_frames: 100 + test_mode: True + preprocess: + - name: GetNeighboringFramesIdx + interval_list: [1] + - name: ReadImageSequence + key: lq + - name: ReadImageSequence + key: gt + - name: Transforms + input_keys: [lq, gt] + pipeline: + - name: TransposeSequence + keys: [image, image] + - name: NormalizeSequence + mean: [0., 0., 0.] + std: [255., 255., 255.] + keys: [image, image] + lr_scheduler: name: CosineAnnealingRestartLR diff --git a/configs/edvr_l_blur_w_tsa.yaml b/configs/edvr_l_blur_w_tsa.yaml index 91ea165..b89440d 100644 --- a/configs/edvr_l_blur_w_tsa.yaml +++ b/configs/edvr_l_blur_w_tsa.yaml @@ -20,47 +20,79 @@ model: front_RBs: 5 back_RBs: 40 center: 2 - predeblur: True #False - HR_in: True #False + predeblur: True + HR_in: True w_TSA: True pixel_criterion: name: CharbonnierLoss dataset: train: - name: REDSDataset - mode: train - gt_folder: data/REDS/train_sharp/X4 - lq_folder: data/REDS/train_blur/X4 - img_format: png - crop_size: 256 - interval_list: [1] - random_reverse: False - number_frames: 5 - use_flip: True - use_rot: True - buf_size: 1024 - scale: 1 - fix_random_seed: 10 + name: RepeatDataset + times: 1000 num_workers: 6 - batch_size: 8 + batch_size: 8 #4 gpus + dataset: + name: VSRREDSDataset + lq_folder: data/REDS/train_blur/X4 + gt_folder: data/REDS/train_sharp/X4 + ann_file: data/REDS/meta_info_REDS_GT.txt + num_frames: 5 + val_partition: REDS4 + test_mode: False + preprocess: + - name: GetFrameIdx + interval_list: [1] + frames_per_clip: 99 + - name: ReadImageSequence + key: lq + - name: ReadImageSequence + key: gt + - 
name: Transforms + input_keys: [lq, gt] + pipeline: + - name: SRPairedRandomCrop + gt_patch_size: 256 + scale: 1 + keys: [image, image] + - name: PairedRandomHorizontalFlip + keys: [image, image] + - name: PairedRandomVerticalFlip + keys: [image, image] + - name: PairedRandomTransposeHW + keys: [image, image] + - name: TransposeSequence + keys: [image, image] + - name: NormalizeSequence + mean: [0., 0., 0.] + std: [255., 255., 255.] + keys: [image, image] test: - name: REDSDataset - mode: test - gt_folder: data/REDS/REDS4_test_sharp/X4 + name: VSRREDSDataset lq_folder: data/REDS/REDS4_test_blur/X4 - img_format: png - interval_list: [1] - random_reverse: False - number_frames: 5 - batch_size: 1 - use_flip: False - use_rot: False - buf_size: 1024 - scale: 1 - fix_random_seed: 10 + gt_folder: data/REDS/REDS4_test_sharp/X4 + ann_file: data/REDS/meta_info_REDS_GT.txt + num_frames: 5 + val_partition: REDS4 + test_mode: True + preprocess: + - name: GetFrameIdxwithPadding + padding: reflection_circle + - name: ReadImageSequence + key: lq + - name: ReadImageSequence + key: gt + - name: Transforms + input_keys: [lq, gt] + pipeline: + - name: TransposeSequence + keys: [image, image] + - name: NormalizeSequence + mean: [0., 0., 0.] + std: [255., 255., 255.] 
+ keys: [image, image] lr_scheduler: name: CosineAnnealingRestartLR diff --git a/configs/edvr_l_blur_wo_tsa.yaml b/configs/edvr_l_blur_wo_tsa.yaml index 214f054..a6b3ee4 100644 --- a/configs/edvr_l_blur_wo_tsa.yaml +++ b/configs/edvr_l_blur_wo_tsa.yaml @@ -19,47 +19,79 @@ model: front_RBs: 5 back_RBs: 40 center: 2 - predeblur: True #False - HR_in: True #False + predeblur: True + HR_in: True w_TSA: False pixel_criterion: name: CharbonnierLoss dataset: train: - name: REDSDataset - mode: train - gt_folder: data/REDS/train_sharp/X4 - lq_folder: data/REDS/train_blur/X4 - img_format: png - crop_size: 256 - interval_list: [1] - random_reverse: False - number_frames: 5 - use_flip: True - use_rot: True - buf_size: 1024 - scale: 1 - fix_random_seed: 10 + name: RepeatDataset + times: 1000 num_workers: 6 - batch_size: 8 + batch_size: 8 #4 gpus + dataset: + name: VSRREDSDataset + lq_folder: data/REDS/train_blur/X4 + gt_folder: data/REDS/train_sharp/X4 + ann_file: data/REDS/meta_info_REDS_GT.txt + num_frames: 5 + val_partition: REDS4 + test_mode: False + preprocess: + - name: GetFrameIdx + interval_list: [1] + frames_per_clip: 99 + - name: ReadImageSequence + key: lq + - name: ReadImageSequence + key: gt + - name: Transforms + input_keys: [lq, gt] + pipeline: + - name: SRPairedRandomCrop + gt_patch_size: 256 + scale: 1 + keys: [image, image] + - name: PairedRandomHorizontalFlip + keys: [image, image] + - name: PairedRandomVerticalFlip + keys: [image, image] + - name: PairedRandomTransposeHW + keys: [image, image] + - name: TransposeSequence + keys: [image, image] + - name: NormalizeSequence + mean: [0., 0., 0.] + std: [255., 255., 255.] 
+ keys: [image, image] test: - name: REDSDataset - mode: test - gt_folder: data/REDS/REDS4_test_sharp/X4 + name: VSRREDSDataset lq_folder: data/REDS/REDS4_test_blur/X4 - img_format: png - interval_list: [1] - random_reverse: False - number_frames: 5 - batch_size: 1 - use_flip: False - use_rot: False - buf_size: 1024 - scale: 1 - fix_random_seed: 10 + gt_folder: data/REDS/REDS4_test_sharp/X4 + ann_file: data/REDS/meta_info_REDS_GT.txt + num_frames: 5 + val_partition: REDS4 + test_mode: True + preprocess: + - name: GetFrameIdxwithPadding + padding: reflection_circle + - name: ReadImageSequence + key: lq + - name: ReadImageSequence + key: gt + - name: Transforms + input_keys: [lq, gt] + pipeline: + - name: TransposeSequence + keys: [image, image] + - name: NormalizeSequence + mean: [0., 0., 0.] + std: [255., 255., 255.] + keys: [image, image] lr_scheduler: name: CosineAnnealingRestartLR diff --git a/configs/edvr_l_w_tsa.yaml b/configs/edvr_l_w_tsa.yaml index 6c17929..fa79ded 100644 --- a/configs/edvr_l_w_tsa.yaml +++ b/configs/edvr_l_w_tsa.yaml @@ -28,39 +28,71 @@ model: dataset: train: - name: REDSDataset - mode: train - gt_folder: data/REDS/train_sharp/X4 - lq_folder: data/REDS/train_sharp_bicubic/X4 - img_format: png - crop_size: 256 - interval_list: [1] - random_reverse: False - number_frames: 5 - use_flip: True - use_rot: True - buf_size: 1024 - scale: 4 - fix_random_seed: 10 + name: RepeatDataset + times: 1000 num_workers: 3 - batch_size: 4 # 8GPUs + batch_size: 4 #8 gpus + dataset: + name: VSRREDSDataset + lq_folder: data/REDS/train_sharp_bicubic/X4 + gt_folder: data/REDS/train_sharp/X4 + ann_file: data/REDS/meta_info_REDS_GT.txt + num_frames: 5 + val_partition: REDS4 + test_mode: False + preprocess: + - name: GetFrameIdx + interval_list: [1] + frames_per_clip: 99 + - name: ReadImageSequence + key: lq + - name: ReadImageSequence + key: gt + - name: Transforms + input_keys: [lq, gt] + pipeline: + - name: SRPairedRandomCrop + gt_patch_size: 256 + scale: 4 + keys: 
[image, image] + - name: PairedRandomHorizontalFlip + keys: [image, image] + - name: PairedRandomVerticalFlip + keys: [image, image] + - name: PairedRandomTransposeHW + keys: [image, image] + - name: TransposeSequence + keys: [image, image] + - name: NormalizeSequence + mean: [0., 0., 0.] + std: [255., 255., 255.] + keys: [image, image] test: - name: REDSDataset - mode: test - gt_folder: data/REDS/REDS4_test_sharp/X4 + name: VSRREDSDataset lq_folder: data/REDS/REDS4_test_sharp_bicubic/X4 - img_format: png - interval_list: [1] - random_reverse: False - number_frames: 5 - batch_size: 1 - use_flip: False - use_rot: False - buf_size: 1024 - scale: 4 - fix_random_seed: 10 + gt_folder: data/REDS/REDS4_test_sharp/X4 + ann_file: data/REDS/meta_info_REDS_GT.txt + num_frames: 5 + val_partition: REDS4 + test_mode: True + preprocess: + - name: GetFrameIdxwithPadding + padding: reflection_circle + - name: ReadImageSequence + key: lq + - name: ReadImageSequence + key: gt + - name: Transforms + input_keys: [lq, gt] + pipeline: + - name: TransposeSequence + keys: [image, image] + - name: NormalizeSequence + mean: [0., 0., 0.] + std: [255., 255., 255.] 
+ keys: [image, image] lr_scheduler: name: CosineAnnealingRestartLR diff --git a/configs/edvr_l_wo_tsa.yaml b/configs/edvr_l_wo_tsa.yaml index d209010..891c009 100644 --- a/configs/edvr_l_wo_tsa.yaml +++ b/configs/edvr_l_wo_tsa.yaml @@ -27,39 +27,71 @@ model: dataset: train: - name: REDSDataset - mode: train - gt_folder: data/REDS/train_sharp/X4 - lq_folder: data/REDS/train_sharp_bicubic/X4 - img_format: png - crop_size: 256 - interval_list: [1] - random_reverse: False - number_frames: 5 - use_flip: True - use_rot: True - buf_size: 1024 - scale: 4 - fix_random_seed: 10 + name: RepeatDataset + times: 1000 num_workers: 3 - batch_size: 4 # 8GPUs + batch_size: 4 #8 gpus + dataset: + name: VSRREDSDataset + lq_folder: data/REDS/train_sharp_bicubic/X4 + gt_folder: data/REDS/train_sharp/X4 + ann_file: data/REDS/meta_info_REDS_GT.txt + num_frames: 5 + val_partition: REDS4 + test_mode: False + preprocess: + - name: GetFrameIdx + interval_list: [1] + frames_per_clip: 99 + - name: ReadImageSequence + key: lq + - name: ReadImageSequence + key: gt + - name: Transforms + input_keys: [lq, gt] + pipeline: + - name: SRPairedRandomCrop + gt_patch_size: 256 + scale: 4 + keys: [image, image] + - name: PairedRandomHorizontalFlip + keys: [image, image] + - name: PairedRandomVerticalFlip + keys: [image, image] + - name: PairedRandomTransposeHW + keys: [image, image] + - name: TransposeSequence + keys: [image, image] + - name: NormalizeSequence + mean: [0., 0., 0.] + std: [255., 255., 255.] 
+ keys: [image, image] test: - name: REDSDataset - mode: test - gt_folder: data/REDS/REDS4_test_sharp/X4 + name: VSRREDSDataset lq_folder: data/REDS/REDS4_test_sharp_bicubic/X4 - img_format: png - interval_list: [1] - random_reverse: False - number_frames: 5 - batch_size: 1 - use_flip: False - use_rot: False - buf_size: 1024 - scale: 4 - fix_random_seed: 10 + gt_folder: data/REDS/REDS4_test_sharp/X4 + ann_file: data/REDS/meta_info_REDS_GT.txt + num_frames: 5 + val_partition: REDS4 + test_mode: True + preprocess: + - name: GetFrameIdxwithPadding + padding: reflection_circle + - name: ReadImageSequence + key: lq + - name: ReadImageSequence + key: gt + - name: Transforms + input_keys: [lq, gt] + pipeline: + - name: TransposeSequence + keys: [image, image] + - name: NormalizeSequence + mean: [0., 0., 0.] + std: [255., 255., 255.] + keys: [image, image] lr_scheduler: name: CosineAnnealingRestartLR diff --git a/configs/edvr_m_w_tsa.yaml b/configs/edvr_m_w_tsa.yaml index aef89cb..ce14a40 100644 --- a/configs/edvr_m_w_tsa.yaml +++ b/configs/edvr_m_w_tsa.yaml @@ -31,39 +31,71 @@ export_model: dataset: train: - name: REDSDataset - mode: train - gt_folder: data/REDS/train_sharp/X4 - lq_folder: data/REDS/train_sharp_bicubic/X4 - img_format: png - crop_size: 256 - interval_list: [1] - random_reverse: False - number_frames: 5 - use_flip: True - use_rot: True - buf_size: 1024 - scale: 4 - fix_random_seed: 10 + name: RepeatDataset + times: 1000 num_workers: 3 - batch_size: 4 # 8GPUs + batch_size: 4 #8 gpus + dataset: + name: VSRREDSDataset + lq_folder: data/REDS/train_sharp_bicubic/X4 + gt_folder: data/REDS/train_sharp/X4 + ann_file: data/REDS/meta_info_REDS_GT.txt + num_frames: 5 + val_partition: REDS4 + test_mode: False + preprocess: + - name: GetFrameIdx + interval_list: [1] + frames_per_clip: 99 + - name: ReadImageSequence + key: lq + - name: ReadImageSequence + key: gt + - name: Transforms + input_keys: [lq, gt] + pipeline: + - name: SRPairedRandomCrop + gt_patch_size: 256 + 
scale: 4 + keys: [image, image] + - name: PairedRandomHorizontalFlip + keys: [image, image] + - name: PairedRandomVerticalFlip + keys: [image, image] + - name: PairedRandomTransposeHW + keys: [image, image] + - name: TransposeSequence + keys: [image, image] + - name: NormalizeSequence + mean: [0., 0., 0.] + std: [255., 255., 255.] + keys: [image, image] test: - name: REDSDataset - mode: test - gt_folder: data/REDS/REDS4_test_sharp/X4 + name: VSRREDSDataset lq_folder: data/REDS/REDS4_test_sharp_bicubic/X4 - img_format: png - interval_list: [1] - random_reverse: False - number_frames: 5 - batch_size: 1 - use_flip: False - use_rot: False - buf_size: 1024 - scale: 4 - fix_random_seed: 10 + gt_folder: data/REDS/REDS4_test_sharp/X4 + ann_file: data/REDS/meta_info_REDS_GT.txt + num_frames: 5 + val_partition: REDS4 + test_mode: True + preprocess: + - name: GetFrameIdxwithPadding + padding: reflection_circle + - name: ReadImageSequence + key: lq + - name: ReadImageSequence + key: gt + - name: Transforms + input_keys: [lq, gt] + pipeline: + - name: TransposeSequence + keys: [image, image] + - name: NormalizeSequence + mean: [0., 0., 0.] + std: [255., 255., 255.] 
+ keys: [image, image] lr_scheduler: name: CosineAnnealingRestartLR diff --git a/configs/edvr_m_wo_tsa.yaml b/configs/edvr_m_wo_tsa.yaml index d6502ed..5c3b061 100644 --- a/configs/edvr_m_wo_tsa.yaml +++ b/configs/edvr_m_wo_tsa.yaml @@ -27,39 +27,71 @@ model: dataset: train: - name: REDSDataset - mode: train - gt_folder: data/REDS/train_sharp/X4 - lq_folder: data/REDS/train_sharp_bicubic/X4 - img_format: png - crop_size: 256 - interval_list: [1] - random_reverse: False - number_frames: 5 - use_flip: True - use_rot: True - buf_size: 1024 - scale: 4 - fix_random_seed: 10 + name: RepeatDataset + times: 1000 num_workers: 3 - batch_size: 4 # 8GPUs + batch_size: 4 #8 gpus + dataset: + name: VSRREDSDataset + lq_folder: data/REDS/train_sharp_bicubic/X4 + gt_folder: data/REDS/train_sharp/X4 + ann_file: data/REDS/meta_info_REDS_GT.txt + num_frames: 5 + val_partition: REDS4 + test_mode: False + preprocess: + - name: GetFrameIdx + interval_list: [1] + frames_per_clip: 99 + - name: ReadImageSequence + key: lq + - name: ReadImageSequence + key: gt + - name: Transforms + input_keys: [lq, gt] + pipeline: + - name: SRPairedRandomCrop + gt_patch_size: 256 + scale: 4 + keys: [image, image] + - name: PairedRandomHorizontalFlip + keys: [image, image] + - name: PairedRandomVerticalFlip + keys: [image, image] + - name: PairedRandomTransposeHW + keys: [image, image] + - name: TransposeSequence + keys: [image, image] + - name: NormalizeSequence + mean: [0., 0., 0.] + std: [255., 255., 255.] 
+ keys: [image, image] test: - name: REDSDataset - mode: test - gt_folder: data/REDS/REDS4_test_sharp/X4 + name: VSRREDSDataset lq_folder: data/REDS/REDS4_test_sharp_bicubic/X4 - img_format: png - interval_list: [1] - random_reverse: False - number_frames: 5 - batch_size: 1 - use_flip: False - use_rot: False - buf_size: 1024 - scale: 4 - fix_random_seed: 10 + gt_folder: data/REDS/REDS4_test_sharp/X4 + ann_file: data/REDS/meta_info_REDS_GT.txt + num_frames: 5 + val_partition: REDS4 + test_mode: True + preprocess: + - name: GetFrameIdxwithPadding + padding: reflection_circle + - name: ReadImageSequence + key: lq + - name: ReadImageSequence + key: gt + - name: Transforms + input_keys: [lq, gt] + pipeline: + - name: TransposeSequence + keys: [image, image] + - name: NormalizeSequence + mean: [0., 0., 0.] + std: [255., 255., 255.] + keys: [image, image] lr_scheduler: name: CosineAnnealingRestartLR diff --git a/configs/esrgan_psnr_x4_div2k.yaml b/configs/esrgan_psnr_x4_div2k.yaml index 5df56c9..ca3486d 100644 --- a/configs/esrgan_psnr_x4_div2k.yaml +++ b/configs/esrgan_psnr_x4_div2k.yaml @@ -44,7 +44,7 @@ dataset: - name: Transpose keys: [image, image] - name: Normalize - mean: [0., .0, 0.] + mean: [0., 0., 0.] std: [255., 255., 255.] keys: [image, image] test: @@ -63,7 +63,7 @@ dataset: - name: Transpose keys: [image, image] - name: Normalize - mean: [0., .0, 0.] + mean: [0., 0., 0.] std: [255., 255., 255.] keys: [image, image] diff --git a/configs/esrgan_x4_div2k.yaml b/configs/esrgan_x4_div2k.yaml index 5202389..aa9e9ab 100644 --- a/configs/esrgan_x4_div2k.yaml +++ b/configs/esrgan_x4_div2k.yaml @@ -64,7 +64,7 @@ dataset: - name: Transpose keys: [image, image] - name: Normalize - mean: [0., .0, 0.] + mean: [0., 0., 0.] std: [255., 255., 255.] keys: [image, image] test: @@ -83,7 +83,7 @@ dataset: - name: Transpose keys: [image, image] - name: Normalize - mean: [0., .0, 0.] + mean: [0., 0., 0.] std: [255., 255., 255.] 
keys: [image, image] diff --git a/configs/iconvsr_reds.yaml b/configs/iconvsr_reds.yaml index 32ea630..c3f8262 100644 --- a/configs/iconvsr_reds.yaml +++ b/configs/iconvsr_reds.yaml @@ -24,35 +24,63 @@ dataset: name: RepeatDataset times: 1000 num_workers: 4 - batch_size: 2 #4 GPUs + batch_size: 2 #4 gpus dataset: - name: SRREDSMultipleGTDataset - mode: train + name: VSRREDSMultipleGTDataset lq_folder: data/REDS/train_sharp_bicubic/X4 gt_folder: data/REDS/train_sharp/X4 - crop_size: 256 - interval_list: [1] - random_reverse: False - number_frames: 15 - use_flip: True - use_rot: True - scale: 4 - val_partition: REDS4 + ann_file: data/REDS/meta_info_REDS_GT.txt + num_frames: 15 + preprocess: + - name: GetNeighboringFramesIdx + interval_list: [1] + - name: ReadImageSequence + key: lq + - name: ReadImageSequence + key: gt + - name: Transforms + input_keys: [lq, gt] + pipeline: + - name: SRPairedRandomCrop + gt_patch_size: 256 + scale: 4 + keys: [image, image] + - name: PairedRandomHorizontalFlip + keys: [image, image] + - name: PairedRandomVerticalFlip + keys: [image, image] + - name: PairedRandomTransposeHW + keys: [image, image] + - name: TransposeSequence + keys: [image, image] + - name: NormalizeSequence + mean: [0., 0., 0.] + std: [255., 255., 255.] 
+ keys: [image, image] test: - name: SRREDSMultipleGTDataset - mode: test + name: VSRREDSMultipleGTDataset lq_folder: data/REDS/REDS4_test_sharp_bicubic/X4 gt_folder: data/REDS/REDS4_test_sharp/X4 - interval_list: [1] - random_reverse: False - number_frames: 100 - use_flip: False - use_rot: False - scale: 4 - val_partition: REDS4 - num_workers: 0 - batch_size: 1 + ann_file: data/REDS/meta_info_REDS_GT.txt + num_frames: 100 + test_mode: True + preprocess: + - name: GetNeighboringFramesIdx + interval_list: [1] + - name: ReadImageSequence + key: lq + - name: ReadImageSequence + key: gt + - name: Transforms + input_keys: [lq, gt] + pipeline: + - name: TransposeSequence + keys: [image, image] + - name: NormalizeSequence + mean: [0., 0., 0.] + std: [255., 255., 255.] + keys: [image, image] lr_scheduler: name: CosineAnnealingRestartLR diff --git a/configs/lesrcnn_psnr_x4_div2k.yaml b/configs/lesrcnn_psnr_x4_div2k.yaml index 5f0dba0..6591be2 100644 --- a/configs/lesrcnn_psnr_x4_div2k.yaml +++ b/configs/lesrcnn_psnr_x4_div2k.yaml @@ -40,7 +40,7 @@ dataset: - name: Transpose keys: [image, image] - name: Normalize - mean: [0., .0, 0.] + mean: [0., 0., 0.] std: [255., 255., 255.] keys: [image, image] test: @@ -59,7 +59,7 @@ dataset: - name: Transpose keys: [image, image] - name: Normalize - mean: [0., .0, 0.] + mean: [0., 0., 0.] std: [255., 255., 255.] 
keys: [image, image] diff --git a/configs/msvsr_l_reds.yaml b/configs/msvsr_l_reds.yaml index 6f5c022..7d42fa5 100644 --- a/configs/msvsr_l_reds.yaml +++ b/configs/msvsr_l_reds.yaml @@ -34,35 +34,62 @@ dataset: times: 1000 num_workers: 4 batch_size: 2 #8 gpus - use_shared_memory: True dataset: - name: SRREDSMultipleGTDataset - mode: train + name: VSRREDSMultipleGTDataset lq_folder: data/REDS/train_sharp_bicubic/X4 gt_folder: data/REDS/train_sharp/X4 - crop_size: 256 - interval_list: [1] - random_reverse: False - number_frames: 30 - use_flip: True - use_rot: True - scale: 4 - val_partition: REDS4 + ann_file: data/REDS/meta_info_REDS_GT.txt + num_frames: 30 + preprocess: + - name: GetNeighboringFramesIdx + interval_list: [1] + - name: ReadImageSequence + key: lq + - name: ReadImageSequence + key: gt + - name: Transforms + input_keys: [lq, gt] + pipeline: + - name: SRPairedRandomCrop + gt_patch_size: 256 + scale: 4 + keys: [image, image] + - name: PairedRandomHorizontalFlip + keys: [image, image] + - name: PairedRandomVerticalFlip + keys: [image, image] + - name: PairedRandomTransposeHW + keys: [image, image] + - name: TransposeSequence + keys: [image, image] + - name: NormalizeSequence + mean: [0., 0., 0.] + std: [255., 255., 255.] + keys: [image, image] test: - name: SRREDSMultipleGTDataset - mode: test + name: VSRREDSMultipleGTDataset lq_folder: data/REDS/REDS4_test_sharp_bicubic/X4 gt_folder: data/REDS/REDS4_test_sharp/X4 - interval_list: [1] - random_reverse: False - number_frames: 100 - use_flip: False - use_rot: False - scale: 4 - val_partition: REDS4 - num_workers: 0 - batch_size: 1 + ann_file: data/REDS/meta_info_REDS_GT.txt + num_frames: 100 + test_mode: True + preprocess: + - name: GetNeighboringFramesIdx + interval_list: [1] + - name: ReadImageSequence + key: lq + - name: ReadImageSequence + key: gt + - name: Transforms + input_keys: [lq, gt] + pipeline: + - name: TransposeSequence + keys: [image, image] + - name: NormalizeSequence + mean: [0., 0., 0.] 
+ std: [255., 255., 255.] + keys: [image, image] lr_scheduler: name: CosineAnnealingRestartLR diff --git a/configs/msvsr_reds.yaml b/configs/msvsr_reds.yaml index 7767780..2efb181 100644 --- a/configs/msvsr_reds.yaml +++ b/configs/msvsr_reds.yaml @@ -34,37 +34,62 @@ dataset: times: 1000 num_workers: 6 batch_size: 2 #8 gpus - use_shared_memory: True dataset: - name: SRREDSMultipleGTDataset - mode: train + name: VSRREDSMultipleGTDataset lq_folder: data/REDS/train_sharp_bicubic/X4 gt_folder: data/REDS/train_sharp/X4 - crop_size: 256 - interval_list: [1] - random_reverse: False - number_frames: 20 - use_flip: True - use_rot: True - scale: 4 - val_partition: REDS4 - num_clips: 270 + ann_file: data/REDS/meta_info_REDS_GT.txt + num_frames: 20 + preprocess: + - name: GetNeighboringFramesIdx + interval_list: [1] + - name: ReadImageSequence + key: lq + - name: ReadImageSequence + key: gt + - name: Transforms + input_keys: [lq, gt] + pipeline: + - name: SRPairedRandomCrop + gt_patch_size: 256 + scale: 4 + keys: [image, image] + - name: PairedRandomHorizontalFlip + keys: [image, image] + - name: PairedRandomVerticalFlip + keys: [image, image] + - name: PairedRandomTransposeHW + keys: [image, image] + - name: TransposeSequence + keys: [image, image] + - name: NormalizeSequence + mean: [0., 0., 0.] + std: [255., 255., 255.] 
+ keys: [image, image] test: - name: SRREDSMultipleGTDataset - mode: test + name: VSRREDSMultipleGTDataset lq_folder: data/REDS/REDS4_test_sharp_bicubic/X4 gt_folder: data/REDS/REDS4_test_sharp/X4 - interval_list: [1] - random_reverse: False - number_frames: 100 - use_flip: False - use_rot: False - scale: 4 - val_partition: REDS4 - num_workers: 0 - batch_size: 1 - num_clips: 270 + ann_file: data/REDS/meta_info_REDS_GT.txt + num_frames: 100 + test_mode: True + preprocess: + - name: GetNeighboringFramesIdx + interval_list: [1] + - name: ReadImageSequence + key: lq + - name: ReadImageSequence + key: gt + - name: Transforms + input_keys: [lq, gt] + pipeline: + - name: TransposeSequence + keys: [image, image] + - name: NormalizeSequence + mean: [0., 0., 0.] + std: [255., 255., 255.] + keys: [image, image] lr_scheduler: name: CosineAnnealingRestartLR @@ -104,4 +129,4 @@ snapshot_config: interval: 5000 export_model: - - {name: 'generator', inputs_num: 1} \ No newline at end of file + - {name: 'generator', inputs_num: 1} diff --git a/configs/msvsr_vimeo90k_BD.yaml b/configs/msvsr_vimeo90k_BD.yaml index 7496c3a..a1d0f86 100644 --- a/configs/msvsr_vimeo90k_BD.yaml +++ b/configs/msvsr_vimeo90k_BD.yaml @@ -36,7 +36,6 @@ dataset: batch_size: 2 #8 gpus dataset: name: VSRVimeo90KDataset - # mode: train lq_folder: data/vimeo90k/vimeo_septuplet_BD_matlabLRx4/sequences gt_folder: data/vimeo90k/vimeo_septuplet/sequences ann_file: data/vimeo90k/vimeo_septuplet/sep_trainlist.txt @@ -62,7 +61,7 @@ dataset: keys: [image, image] - name: MirrorVideoSequence - name: NormalizeSequence - mean: [0., .0, 0.] + mean: [0., 0., 0.] std: [255., 255., 255.] keys: [image, image] @@ -89,7 +88,7 @@ dataset: - name: TransposeSequence keys: [image, image] - name: NormalizeSequence - mean: [0., .0, 0.] + mean: [0., 0., 0.] std: [255., 255., 255.] 
keys: [image, image] diff --git a/configs/realsr_bicubic_noise_x4_df2k.yaml b/configs/realsr_bicubic_noise_x4_df2k.yaml index 6200825..0a19753 100644 --- a/configs/realsr_bicubic_noise_x4_df2k.yaml +++ b/configs/realsr_bicubic_noise_x4_df2k.yaml @@ -61,7 +61,7 @@ dataset: - name: Transpose keys: [image, image] - name: Normalize - mean: [0., .0, 0.] + mean: [0., 0., 0.] std: [255., 255., 255.] keys: [image, image] - name: SRNoise @@ -84,7 +84,7 @@ dataset: - name: Transpose keys: [image, image] - name: Normalize - mean: [0., .0, 0.] + mean: [0., 0., 0.] std: [255., 255., 255.] keys: [image, image] diff --git a/configs/realsr_kernel_noise_x4_dped.yaml b/configs/realsr_kernel_noise_x4_dped.yaml index c2e6eab..ba2851e 100644 --- a/configs/realsr_kernel_noise_x4_dped.yaml +++ b/configs/realsr_kernel_noise_x4_dped.yaml @@ -61,7 +61,7 @@ dataset: - name: Transpose keys: [image, image] - name: Normalize - mean: [0., .0, 0.] + mean: [0., 0., 0.] std: [255., 255., 255.] keys: [image, image] - name: SRNoise @@ -84,7 +84,7 @@ dataset: - name: Transpose keys: [image, image] - name: Normalize - mean: [0., .0, 0.] + mean: [0., 0., 0.] std: [255., 255., 255.] 
keys: [image, image] diff --git a/ppgan/datasets/__init__.py b/ppgan/datasets/__init__.py index e402927..731eb7c 100755 --- a/ppgan/datasets/__init__.py +++ b/ppgan/datasets/__init__.py @@ -21,11 +21,11 @@ from .common_vision_dataset import CommonVisionDataset from .animeganv2_dataset import AnimeGANV2Dataset from .wav2lip_dataset import Wav2LipDataset from .starganv2_dataset import StarGANv2Dataset -from .edvr_dataset import REDSDataset from .firstorder_dataset import FirstOrderDataset from .lapstyle_dataset import LapStyleDataset -from .sr_reds_multiple_gt_dataset import SRREDSMultipleGTDataset from .mpr_dataset import MPRTrain, MPRVal, MPRTest +from .vsr_reds_dataset import VSRREDSDataset +from .vsr_reds_multiple_gt_dataset import VSRREDSMultipleGTDataset from .vsr_vimeo90k_dataset import VSRVimeo90KDataset from .vsr_folder_dataset import VSRFolderDataset from .photopen_dataset import PhotoPenDataset diff --git a/ppgan/datasets/edvr_dataset.py b/ppgan/datasets/edvr_dataset.py deleted file mode 100644 index c9b587c..0000000 --- a/ppgan/datasets/edvr_dataset.py +++ /dev/null @@ -1,313 +0,0 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import logging -import os -import random -import numpy as np -import scipy.io as scio -import cv2 -import paddle -from paddle.io import Dataset, DataLoader -from .builder import DATASETS - -logger = logging.getLogger(__name__) - - -@DATASETS.register() -class REDSDataset(Dataset): - """ - REDS dataset for EDVR model - """ - def __init__(self, - mode, - lq_folder, - gt_folder, - img_format="png", - crop_size=256, - interval_list=[1], - random_reverse=False, - number_frames=5, - batch_size=32, - use_flip=False, - use_rot=False, - buf_size=1024, - scale=4, - fix_random_seed=False): - super(REDSDataset, self).__init__() - self.format = img_format - self.mode = mode - self.crop_size = crop_size - self.interval_list = interval_list - self.random_reverse = random_reverse - self.number_frames = number_frames - self.batch_size = batch_size - self.fileroot = lq_folder - self.use_flip = use_flip - self.use_rot = use_rot - self.buf_size = buf_size - self.fix_random_seed = fix_random_seed - - if self.mode != 'infer': - self.gtroot = gt_folder - self.scale = scale - self.LR_input = (self.scale > 1) - if self.fix_random_seed: - random.seed(10) - np.random.seed(10) - self.num_reader_threads = 1 - - self._init_() - - def _init_(self): - logger.info('initialize reader ... 
') - print("initialize reader") - self.filelist = [] - for video_name in os.listdir(self.fileroot): - if (self.mode == 'train') and (video_name in [ - '000', '011', '015', '020' - ]): #These four videos are used as val - continue - for frame_name in os.listdir(os.path.join(self.fileroot, - video_name)): - frame_idx = frame_name.split('.')[0] - video_frame_idx = video_name + '_' + str(frame_idx) - # for each item in self.filelist is like '010_00000015', '260_00000090' - self.filelist.append(video_frame_idx) - if self.mode == 'test': - self.filelist.sort() - print(len(self.filelist)) - - def __getitem__(self, index): - """Get training sample - - return: lq:[5,3,W,H], - gt:[3,W,H], - lq_path:str - """ - item = self.filelist[index] - img_LQs, img_GT = self.get_sample_data( - item, self.number_frames, self.interval_list, self.random_reverse, - self.gtroot, self.fileroot, self.LR_input, self.crop_size, - self.scale, self.use_flip, self.use_rot, self.mode) - return {'lq': img_LQs, 'gt': img_GT, 'lq_path': self.filelist[index]} - - def get_sample_data(self, - item, - number_frames, - interval_list, - random_reverse, - gtroot, - fileroot, - LR_input, - crop_size, - scale, - use_flip, - use_rot, - mode='train'): - video_name = item.split('_')[0] - frame_name = item.split('_')[1] - if (mode == 'train') or (mode == 'valid'): - ngb_frames, name_b = self.get_neighbor_frames(frame_name, \ - number_frames=number_frames, \ - interval_list=interval_list, \ - random_reverse=random_reverse) - elif mode == 'test': - ngb_frames, name_b = self.get_test_neighbor_frames( - int(frame_name), number_frames) - else: - raise NotImplementedError('mode {} not implemented'.format(mode)) - frame_name = name_b - img_GT = self.read_img( - os.path.join(gtroot, video_name, frame_name + '.png')) - frame_list = [] - for ngb_frm in ngb_frames: - ngb_name = "%08d" % ngb_frm - img = self.read_img( - os.path.join(fileroot, video_name, ngb_name + '.png')) - frame_list.append(img) - H, W, C = 
frame_list[0].shape - # add random crop - if (mode == 'train') or (mode == 'valid'): - if LR_input: - LQ_size = crop_size // scale - rnd_h = random.randint(0, max(0, H - LQ_size)) - rnd_w = random.randint(0, max(0, W - LQ_size)) - frame_list = [ - v[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :] - for v in frame_list - ] - rnd_h_HR, rnd_w_HR = int(rnd_h * scale), int(rnd_w * scale) - img_GT = img_GT[rnd_h_HR:rnd_h_HR + crop_size, - rnd_w_HR:rnd_w_HR + crop_size, :] - else: - rnd_h = random.randint(0, max(0, H - crop_size)) - rnd_w = random.randint(0, max(0, W - crop_size)) - frame_list = [ - v[rnd_h:rnd_h + crop_size, rnd_w:rnd_w + crop_size, :] - for v in frame_list - ] - img_GT = img_GT[rnd_h:rnd_h + crop_size, - rnd_w:rnd_w + crop_size, :] - - # add random flip and rotation - frame_list.append(img_GT) - if (mode == 'train') or (mode == 'valid'): - rlt = self.img_augment(frame_list, use_flip, use_rot) - else: - rlt = frame_list - frame_list = rlt[0:-1] - img_GT = rlt[-1] - - # stack LQ images to NHWC, N is the frame number - img_LQs = np.stack(frame_list, axis=0) - # BGR to RGB, HWC to CHW, numpy to tensor - img_GT = img_GT[:, :, [2, 1, 0]] - img_LQs = img_LQs[:, :, :, [2, 1, 0]] - img_GT = np.transpose(img_GT, (2, 0, 1)).astype('float32') - img_LQs = np.transpose(img_LQs, (0, 3, 1, 2)).astype('float32') - - return img_LQs, img_GT - - def get_neighbor_frames(self, - frame_name, - number_frames, - interval_list, - random_reverse, - max_frame=99, - bordermode=False): - center_frame_idx = int(frame_name) - half_N_frames = number_frames // 2 - interval = random.choice(interval_list) - if bordermode: - direction = 1 - if random_reverse and random.random() < 0.5: - direction = random.choice([0, 1]) - if center_frame_idx + interval * (number_frames - 1) > max_frame: - direction = 0 - elif center_frame_idx - interval * (number_frames - 1) < 0: - direction = 1 - if direction == 1: - neighbor_list = list( - range(center_frame_idx, - center_frame_idx + interval * 
number_frames, - interval)) - else: - neighbor_list = list( - range(center_frame_idx, - center_frame_idx - interval * number_frames, - -interval)) - name_b = '{:08d}'.format(neighbor_list[0]) - else: - # ensure not exceeding the borders - while (center_frame_idx + half_N_frames * interval > max_frame) or ( - center_frame_idx - half_N_frames * interval < 0): - center_frame_idx = random.randint(0, max_frame) - neighbor_list = list( - range(center_frame_idx - half_N_frames * interval, - center_frame_idx + half_N_frames * interval + 1, - interval)) - if random_reverse and random.random() < 0.5: - neighbor_list.reverse() - name_b = '{:08d}'.format(neighbor_list[half_N_frames]) - assert len(neighbor_list) == number_frames, \ - "frames slected have length({}), but it should be ({})".format(len(neighbor_list), number_frames) - - return neighbor_list, name_b - - def read_img(self, path, size=None): - """read image by cv2 - - return: Numpy float32, HWC, BGR, [0,1] - """ - img = cv2.imread(path, cv2.IMREAD_UNCHANGED) - img = img.astype(np.float32) / 255. 
- if img.ndim == 2: - img = np.expand_dims(img, axis=2) - # some images have 4 channels - if img.shape[2] > 3: - img = img[:, :, :3] - return img - - def img_augment(self, img_list, hflip=True, rot=True): - """horizontal flip OR rotate (0, 90, 180, 270 degrees) - """ - hflip = hflip and random.random() < 0.5 - vflip = rot and random.random() < 0.5 - rot90 = rot and random.random() < 0.5 - - def _augment(img): - if hflip: - img = img[:, ::-1, :] - if vflip: - img = img[::-1, :, :] - if rot90: - img = img.transpose(1, 0, 2) - return img - - return [_augment(img) for img in img_list] - - def get_test_neighbor_frames(self, crt_i, N, max_n=100, padding='new_info'): - """Generate an index list for reading N frames from a sequence of images - Args: - crt_i (int): current center index - max_n (int): max number of the sequence of images (calculated from 1) - N (int): reading N frames - padding (str): padding mode, one of replicate | reflection | new_info | circle - Example: crt_i = 0, N = 5 - replicate: [0, 0, 0, 1, 2] - reflection: [2, 1, 0, 1, 2] - new_info: [4, 3, 0, 1, 2] - circle: [3, 4, 0, 1, 2] - - Returns: - return_l (list [int]): a list of indexes - """ - max_n = max_n - 1 - n_pad = N // 2 - return_l = [] - - for i in range(crt_i - n_pad, crt_i + n_pad + 1): - if i < 0: - if padding == 'replicate': - add_idx = 0 - elif padding == 'reflection': - add_idx = -i - elif padding == 'new_info': - add_idx = (crt_i + n_pad) + (-i) - elif padding == 'circle': - add_idx = N + i - else: - raise ValueError('Wrong padding mode') - elif i > max_n: - if padding == 'replicate': - add_idx = max_n - elif padding == 'reflection': - add_idx = max_n * 2 - i - elif padding == 'new_info': - add_idx = (crt_i - n_pad) - (i - max_n) - elif padding == 'circle': - add_idx = i - N - else: - raise ValueError('Wrong padding mode') - else: - add_idx = i - return_l.append(add_idx) - name_b = '{:08d}'.format(crt_i) - return return_l, name_b - - def __len__(self): - """Return the total number of 
images in the dataset. - """ - return len(self.filelist) diff --git a/ppgan/datasets/preprocess/__init__.py b/ppgan/datasets/preprocess/__init__.py index 1712224..6b73cbf 100644 --- a/ppgan/datasets/preprocess/__init__.py +++ b/ppgan/datasets/preprocess/__init__.py @@ -1,4 +1,4 @@ -from .io import LoadImageFromFile, ReadImageSequence, GetNeighboringFramesIdx +from .io import LoadImageFromFile, ReadImageSequence, GetNeighboringFramesIdx, GetFrameIdx, GetFrameIdxwithPadding from .transforms import (PairedRandomCrop, PairedRandomHorizontalFlip, PairedRandomVerticalFlip, PairedRandomTransposeHW, SRPairedRandomCrop, SplitPairedImage, SRNoise, diff --git a/ppgan/datasets/preprocess/io.py b/ppgan/datasets/preprocess/io.py index d8ce34e..e5123ed 100644 --- a/ppgan/datasets/preprocess/io.py +++ b/ppgan/datasets/preprocess/io.py @@ -179,3 +179,151 @@ class GetNeighboringFramesIdx: datas['interval'] = interval return datas + + +@PREPROCESS.register() +class GetFrameIdx: + """Generate frame index for REDS datasets. + + Args: + interval_list (list[int]): Interval list for temporal augmentation. + It will randomly pick an interval from interval_list and sample + frame index with the interval. + frames_per_clip(int): Index of the last frame in a clip. Default: 99 for + REDS dataset. + """ + def __init__(self, interval_list, frames_per_clip=99): + self.interval_list = interval_list + self.frames_per_clip = frames_per_clip + + def __call__(self, results): + """Call function. + + Args: + results (dict): A dict containing the necessary information and + data for augmentation. + + Returns: + dict: A dict containing the processed data and information. 
+ """ + clip_name, frame_name = results['key'].split('/') + center_frame_idx = int(frame_name) + num_half_frames = results['num_frames'] // 2 + + interval = np.random.choice(self.interval_list) + # ensure not exceeding the borders + start_frame_idx = center_frame_idx - num_half_frames * interval + end_frame_idx = center_frame_idx + num_half_frames * interval + while (start_frame_idx < 0) or (end_frame_idx > self.frames_per_clip): + center_frame_idx = np.random.randint(0, self.frames_per_clip + 1) + start_frame_idx = center_frame_idx - num_half_frames * interval + end_frame_idx = center_frame_idx + num_half_frames * interval + frame_name = f'{center_frame_idx:08d}' + neighbor_list = list( + range(center_frame_idx - num_half_frames * interval, + center_frame_idx + num_half_frames * interval + 1, interval)) + + lq_path_root = results['lq_path'] + gt_path_root = results['gt_path'] + lq_path = [ + os.path.join(lq_path_root, clip_name, f'{v:08d}.png') + for v in neighbor_list + ] + gt_path = [os.path.join(gt_path_root, clip_name, f'{frame_name}.png')] + results['lq_path'] = lq_path + results['gt_path'] = gt_path + results['interval'] = interval + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += (f'(interval_list={self.interval_list}, ' + f'frames_per_clip={self.frames_per_clip})') + return repr_str + + +@PREPROCESS.register() +class GetFrameIdxwithPadding: + """Generate frame index with padding for REDS dataset and Vid4 dataset + during testing. + + Args: + padding (str): padding mode, one of + 'replicate' | 'reflection' | 'reflection_circle' | 'circle'. + + Examples: current_idx = 0, num_frames = 5 + The generated frame indices under different padding mode: + + replicate: [0, 0, 0, 1, 2] + reflection: [2, 1, 0, 1, 2] + reflection_circle: [4, 3, 0, 1, 2] + circle: [3, 4, 0, 1, 2] + + filename_tmpl (str): Template for file name. Default: '{:08d}'. 
+ """ + def __init__(self, padding, filename_tmpl='{:08d}'): + if padding not in ('replicate', 'reflection', 'reflection_circle', + 'circle'): + raise ValueError(f'Wrong padding mode {padding}. ' + 'Should be "replicate", "reflection", ' + '"reflection_circle", "circle"') + self.padding = padding + self.filename_tmpl = filename_tmpl + + def __call__(self, results): + """Call function. + + Args: + results (dict): A dict containing the necessary information and + data for augmentation. + + Returns: + dict: A dict containing the processed data and information. + """ + clip_name, frame_name = results['key'].split('/') + current_idx = int(frame_name) + max_frame_num = results['max_frame_num'] - 1 # start from 0 + num_frames = results['num_frames'] + num_pad = num_frames // 2 + + frame_list = [] + for i in range(current_idx - num_pad, current_idx + num_pad + 1): + if i < 0: + if self.padding == 'replicate': + pad_idx = 0 + elif self.padding == 'reflection': + pad_idx = -i + elif self.padding == 'reflection_circle': + pad_idx = current_idx + num_pad - i + else: + pad_idx = num_frames + i + elif i > max_frame_num: + if self.padding == 'replicate': + pad_idx = max_frame_num + elif self.padding == 'reflection': + pad_idx = max_frame_num * 2 - i + elif self.padding == 'reflection_circle': + pad_idx = (current_idx - num_pad) - (i - max_frame_num) + else: + pad_idx = i - num_frames + else: + pad_idx = i + frame_list.append(pad_idx) + + lq_path_root = results['lq_path'] + gt_path_root = results['gt_path'] + lq_paths = [ + os.path.join(lq_path_root, clip_name, + f'{self.filename_tmpl.format(idx)}.png') + for idx in frame_list + ] + gt_paths = [os.path.join(gt_path_root, clip_name, f'{frame_name}.png')] + results['lq_path'] = lq_paths + results['gt_path'] = gt_paths + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + f"(padding='{self.padding}')" + return repr_str diff --git a/ppgan/datasets/sr_reds_multiple_gt_dataset.py 
b/ppgan/datasets/sr_reds_multiple_gt_dataset.py deleted file mode 100644 index cf59409..0000000 --- a/ppgan/datasets/sr_reds_multiple_gt_dataset.py +++ /dev/null @@ -1,237 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -import os -import random -import numpy as np -import cv2 -from paddle.io import Dataset - -from .builder import DATASETS - -logger = logging.getLogger(__name__) - - -@DATASETS.register() -class SRREDSMultipleGTDataset(Dataset): - """REDS dataset for video super resolution for recurrent networks. - - The dataset loads several LQ (Low-Quality) frames and GT (Ground-Truth) - frames. Then it applies specified transforms and finally returns a dict - containing paired data and other information. - - Args: - lq_folder (str | :obj:`Path`): Path to a lq folder. - gt_folder (str | :obj:`Path`): Path to a gt folder. - num_input_frames (int): Number of input frames. - pipeline (list[dict | callable]): A sequence of data transformations. - scale (int): Upsampling scale ratio. - val_partition (str): Validation partition mode. Choices ['official' or - 'REDS4']. Default: 'official'. - test_mode (bool): Store `True` when building test dataset. - Default: `False`. 
- """ - def __init__(self, - mode, - lq_folder, - gt_folder, - crop_size=256, - interval_list=[1], - random_reverse=False, - number_frames=15, - use_flip=False, - use_rot=False, - scale=4, - val_partition='REDS4', - batch_size=4, - num_clips=270): - super(SRREDSMultipleGTDataset, self).__init__() - self.mode = mode - self.fileroot = str(lq_folder) - self.gtroot = str(gt_folder) - self.crop_size = crop_size - self.interval_list = interval_list - self.random_reverse = random_reverse - self.number_frames = number_frames - self.use_flip = use_flip - self.use_rot = use_rot - self.scale = scale - self.val_partition = val_partition - self.batch_size = batch_size - self.num_clips = num_clips # training num of LQ and GT pairs - self.data_infos = self.load_annotations() - - def __getitem__(self, idx): - """Get item at each call. - - Args: - idx (int): Index for getting each item. - """ - item = self.data_infos[idx] - idt = random.randint(0, 100 - self.number_frames) - item = item + '_' + f'{idt:03d}' - img_LQs, img_GTs = self.get_sample_data( - item, self.number_frames, self.interval_list, self.random_reverse, - self.gtroot, self.fileroot, self.crop_size, self.scale, - self.use_flip, self.use_rot, self.mode) - return {'lq': img_LQs, 'gt': img_GTs, 'lq_path': self.data_infos[idx]} - - def load_annotations(self): - """Load annoations for REDS dataset. - - Returns: - dict: Returned dict for LQ and GT pairs. - """ - # generate keys - keys = [f'{i:03d}' for i in range(0, self.num_clips)] - - if self.val_partition == 'REDS4': - val_partition = ['000', '011', '015', '020'] - elif self.val_partition == 'official': - val_partition = [f'{i:03d}' for i in range(240, 270)] - else: - raise ValueError(f'Wrong validation partition {self.val_partition}.' 
- f'Supported ones are ["official", "REDS4"]') - - if self.mode == 'train': - keys = [v for v in keys if v not in val_partition] - else: - keys = [v for v in keys if v in val_partition] - - data_infos = [] - for key in keys: - data_infos.append(key) - - return data_infos - - def get_sample_data(self, - item, - number_frames, - interval_list, - random_reverse, - gtroot, - fileroot, - crop_size, - scale, - use_flip, - use_rot, - mode='train'): - video_name = item.split('_')[0] - frame_name = item.split('_')[1] - frame_idxs = self.get_neighbor_frames(frame_name, - number_frames=number_frames, - interval_list=interval_list, - random_reverse=random_reverse) - - frame_list = [] - gt_list = [] - for frame_idx in frame_idxs: - frame_idx_name = "%08d" % frame_idx - img = self.read_img( - os.path.join(fileroot, video_name, frame_idx_name + '.png')) - frame_list.append(img) - gt_img = self.read_img( - os.path.join(gtroot, video_name, frame_idx_name + '.png')) - gt_list.append(gt_img) - H, W, C = frame_list[0].shape - # add random crop - if (mode == 'train') or (mode == 'valid'): - LQ_size = crop_size // scale - rnd_h = random.randint(0, max(0, H - LQ_size)) - rnd_w = random.randint(0, max(0, W - LQ_size)) - frame_list = [ - v[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :] - for v in frame_list - ] - rnd_h_HR, rnd_w_HR = int(rnd_h * scale), int(rnd_w * scale) - gt_list = [ - v[rnd_h_HR:rnd_h_HR + crop_size, - rnd_w_HR:rnd_w_HR + crop_size, :] for v in gt_list - ] - - # add random flip and rotation - for v in gt_list: - frame_list.append(v) - if (mode == 'train') or (mode == 'valid'): - rlt = self.img_augment(frame_list, use_flip, use_rot) - else: - rlt = frame_list - frame_list = rlt[0:number_frames] - gt_list = rlt[number_frames:] - - # stack LQ images to NHWC, N is the frame number - frame_list = [ - v.transpose(2, 0, 1).astype('float32') for v in frame_list - ] - gt_list = [v.transpose(2, 0, 1).astype('float32') for v in gt_list] - - img_LQs = np.stack(frame_list, axis=0) 
- img_GTs = np.stack(gt_list, axis=0) - - return img_LQs, img_GTs - - def get_neighbor_frames(self, frame_name, number_frames, interval_list, - random_reverse): - frame_idx = int(frame_name) - interval = random.choice(interval_list) - neighbor_list = list( - range(frame_idx, frame_idx + number_frames, interval)) - if random_reverse and random.random() < 0.5: - neighbor_list.reverse() - - assert len(neighbor_list) == number_frames, \ - "frames slected have length({}), but it should be ({})".format(len(neighbor_list), number_frames) - - return neighbor_list - - def read_img(self, path, size=None): - """read image by cv2 - - return: Numpy float32, HWC, BGR, [0,1] - """ - img = cv2.imread(path, cv2.IMREAD_UNCHANGED) - img = img.astype(np.float32) / 255. - if img.ndim == 2: - img = np.expand_dims(img, axis=2) - # some images have 4 channels - if img.shape[2] > 3: - img = img[:, :, :3] - return cv2.cvtColor(img, cv2.COLOR_BGR2RGB) - - def img_augment(self, img_list, hflip=True, rot=True): - """horizontal flip OR rotate (0, 90, 180, 270 degrees) - """ - hflip = hflip and random.random() < 0.5 - vflip = rot and random.random() < 0.5 - rot90 = rot and random.random() < 0.5 - - def _augment(img): - if hflip: - img = img[:, ::-1, :] - if vflip: - img = img[::-1, :, :] - if rot90: - img = img.transpose(1, 0, 2) - return img - - return [_augment(img) for img in img_list] - - def __len__(self): - """Length of the dataset. - - Returns: - int: Length of the dataset. - """ - return len(self.data_infos) diff --git a/ppgan/datasets/vsr_reds_dataset.py b/ppgan/datasets/vsr_reds_dataset.py new file mode 100644 index 0000000..9dc665e --- /dev/null +++ b/ppgan/datasets/vsr_reds_dataset.py @@ -0,0 +1,97 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from .builder import DATASETS +from .base_sr_dataset import BaseDataset + +logger = logging.getLogger(__name__) + + +@DATASETS.register() +class VSRREDSDataset(BaseDataset): + """REDS dataset for video super resolution for Sliding-window networks. + + The dataset loads several LQ (Low-Quality) frames and a center GT + (Ground-Truth) frame. Then it applies specified transforms and finally + returns a dict containing paired data and other information. + + It reads REDS keys from the txt file. Each line contains video frame folder + + Examples: + + 000/00000000.png (720, 1280, 3) + 000/00000001.png (720, 1280, 3) + + Args: + lq_folder (str): Path to a low quality image folder. + gt_folder (str): Path to a ground truth image folder. + ann_file (str): Path to the annotation file. + num_frames (int): Window size for input frames. + preprocess (list[dict|callable]): A list of data transformation functions. + val_partition (str): Validation partition mode. Choices ['official' or 'REDS4']. Default: 'REDS4'. + test_mode (bool): Store `True` when building test dataset. Default: `False`. 
+ """ + def __init__(self, + lq_folder, + gt_folder, + ann_file, + num_frames, + preprocess, + val_partition='REDS4', + test_mode=False): + super().__init__(preprocess) + assert num_frames % 2 == 1, (f'num_frames should be odd numbers, ' + f'but received {num_frames }.') + self.lq_folder = str(lq_folder) + self.gt_folder = str(gt_folder) + self.ann_file = str(ann_file) + self.num_frames = num_frames + self.val_partition = val_partition + self.test_mode = test_mode + self.data_infos = self.prepare_data_infos() + + def prepare_data_infos(self): + """Load annotations for REDS dataset. + Returns: + list[dict]: Info dicts for LQ and GT pairs, one per kept key. + """ + # get keys, skipping blank lines + with open(self.ann_file, 'r') as fin: + keys = [v.strip().split('.')[0] for v in fin if v.strip()] + + if self.val_partition == 'REDS4': + val_partition = ['000', '011', '015', '020'] + elif self.val_partition == 'official': + val_partition = [f'{v:03d}' for v in range(240, 270)] + else: + raise ValueError(f'Wrong validation partition {self.val_partition}. ' + f'Supported ones are ["official", "REDS4"]') + + if self.test_mode: + keys = [v for v in keys if v.split('/')[0] in val_partition] + else: + keys = [v for v in keys if v.split('/')[0] not in val_partition] + + data_infos = [] + for key in keys: + data_infos.append( + dict(lq_path=self.lq_folder, + gt_path=self.gt_folder, + key=key, + max_frame_num=100, + num_frames=self.num_frames)) + + return data_infos diff --git a/ppgan/datasets/vsr_reds_multiple_gt_dataset.py b/ppgan/datasets/vsr_reds_multiple_gt_dataset.py new file mode 100644 index 0000000..f327b10 --- /dev/null +++ b/ppgan/datasets/vsr_reds_multiple_gt_dataset.py @@ -0,0 +1,92 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from .builder import DATASETS +from .base_sr_dataset import BaseDataset + +logger = logging.getLogger(__name__) + + +@DATASETS.register() +class VSRREDSMultipleGTDataset(BaseDataset): + """REDS dataset for video super resolution for recurrent networks. + + The dataset loads several LQ (Low-Quality) frames and GT (Ground-Truth) frames. + Then it applies specified transforms and finally returns a dict containing + paired data and other information. + + + Args: + lq_folder (str): Path to a low quality image folder. + gt_folder (str): Path to a ground truth image folder. + ann_file (str): Path to the annotation file. + num_frames (int): Window size for input frames. + preprocess (list[dict|callable]): A list of data transformation functions. + val_partition (str): Validation partition mode. Choices ['official' or 'REDS4']. + Default: 'REDS4'. + test_mode (bool): Store `True` when building test dataset. Default: `False`. + """ + def __init__(self, + lq_folder, + gt_folder, + ann_file, + num_frames, + preprocess, + val_partition='REDS4', + test_mode=False): + super().__init__(preprocess) + self.lq_folder = str(lq_folder) + self.gt_folder = str(gt_folder) + self.ann_file = str(ann_file) + self.num_frames = num_frames + self.val_partition = val_partition + self.test_mode = test_mode + self.data_infos = self.prepare_data_infos() + + def prepare_data_infos(self): + """Load annotations for REDS dataset. + + Returns: + list[dict]: Info dicts for LQ and GT pairs, one per kept clip. 
+ """ + # get clip names (sorted for a reproducible order), skipping blank lines + with open(self.ann_file, 'r') as fin: + keys = [v.strip().split('/')[0] for v in fin if v.strip()] + keys = sorted(set(keys)) + + if self.val_partition == 'REDS4': + val_partition = ['000', '011', '015', '020'] + elif self.val_partition == 'official': + val_partition = [f'{v:03d}' for v in range(240, 270)] + else: + raise ValueError(f'Wrong validation partition {self.val_partition}. ' + f'Supported ones are ["official", "REDS4"]') + + if self.test_mode: + keys = [v for v in keys if v in val_partition] + else: + keys = [v for v in keys if v not in val_partition] + + data_infos = [] + for key in keys: + data_infos.append( + dict(lq_path=self.lq_folder, + gt_path=self.gt_folder, + key=key, + sequence_length=100, + num_frames=self.num_frames)) + + return data_infos diff --git a/ppgan/models/edvr_model.py b/ppgan/models/edvr_model.py index 3fa270d..b95e8f1 100644 --- a/ppgan/models/edvr_model.py +++ b/ppgan/models/edvr_model.py @@ -48,7 +48,7 @@ class EDVRModel(BaseSRModel): self.visual_items['lq+1'] = self.lq[:, 3, :, :, :] self.visual_items['lq+2'] = self.lq[:, 4, :, :, :] if 'gt' in input: - self.gt = input['gt'] + self.gt = input['gt'][:, 0, :, :, :] self.visual_items['gt'] = self.gt self.image_paths = input['lq_path'] -- GitLab