未验证 提交 424ab9eb 编写于 作者: W wangna11BD 提交者: GitHub

fix reds dataset (#527)

* fix reds dataset

* fix CI
上级 a8ae6b7a
......@@ -27,33 +27,61 @@ dataset:
num_workers: 4
batch_size: 2 #4 gpus
dataset:
name: SRREDSMultipleGTDataset
mode: train
name: VSRREDSMultipleGTDataset
lq_folder: data/REDS/train_sharp_bicubic/X4
gt_folder: data/REDS/train_sharp/X4
crop_size: 256
interval_list: [1]
random_reverse: False
number_frames: 30
use_flip: True
use_rot: True
scale: 4
val_partition: REDS4
ann_file: data/REDS/meta_info_REDS_GT.txt
num_frames: 30
preprocess:
- name: GetNeighboringFramesIdx
interval_list: [1]
- name: ReadImageSequence
key: lq
- name: ReadImageSequence
key: gt
- name: Transforms
input_keys: [lq, gt]
pipeline:
- name: SRPairedRandomCrop
gt_patch_size: 256
scale: 4
keys: [image, image]
- name: PairedRandomHorizontalFlip
keys: [image, image]
- name: PairedRandomVerticalFlip
keys: [image, image]
- name: PairedRandomTransposeHW
keys: [image, image]
- name: TransposeSequence
keys: [image, image]
- name: NormalizeSequence
mean: [0., 0., 0.]
std: [255., 255., 255.]
keys: [image, image]
test:
name: SRREDSMultipleGTDataset
mode: test
name: VSRREDSMultipleGTDataset
lq_folder: data/REDS/REDS4_test_sharp_bicubic/X4
gt_folder: data/REDS/REDS4_test_sharp/X4
interval_list: [1]
random_reverse: False
number_frames: 100
use_flip: False
use_rot: False
scale: 4
val_partition: REDS4
num_workers: 0
batch_size: 1
ann_file: data/REDS/meta_info_REDS_GT.txt
num_frames: 100
test_mode: True
preprocess:
- name: GetNeighboringFramesIdx
interval_list: [1]
- name: ReadImageSequence
key: lq
- name: ReadImageSequence
key: gt
- name: Transforms
input_keys: [lq, gt]
pipeline:
- name: TransposeSequence
keys: [image, image]
- name: NormalizeSequence
mean: [0., 0., 0.]
std: [255., 255., 255.]
keys: [image, image]
lr_scheduler:
name: CosineAnnealingRestartLR
......
......@@ -53,7 +53,7 @@ dataset:
keys: [image, image]
- name: MirrorVideoSequence
- name: NormalizeSequence
mean: [0., .0, 0.]
mean: [0., 0., 0.]
std: [255., 255., 255.]
keys: [image, image]
......@@ -80,7 +80,7 @@ dataset:
- name: TransposeSequence
keys: [image, image]
- name: NormalizeSequence
mean: [0., .0, 0.]
mean: [0., 0., 0.]
std: [255., 255., 255.]
keys: [image, image]
......
......@@ -24,37 +24,64 @@ dataset:
name: RepeatDataset
times: 1000
num_workers: 4
batch_size: 2 #4 GPUs
batch_size: 2 #4 gpus
dataset:
name: SRREDSMultipleGTDataset
mode: train
name: VSRREDSMultipleGTDataset
lq_folder: data/REDS/train_sharp_bicubic/X4
gt_folder: data/REDS/train_sharp/X4
crop_size: 256
interval_list: [1]
random_reverse: False
number_frames: 15
use_flip: True
use_rot: True
scale: 4
val_partition: REDS4
num_clips: 270
ann_file: data/REDS/meta_info_REDS_GT.txt
num_frames: 15
preprocess:
- name: GetNeighboringFramesIdx
interval_list: [1]
- name: ReadImageSequence
key: lq
- name: ReadImageSequence
key: gt
- name: Transforms
input_keys: [lq, gt]
pipeline:
- name: SRPairedRandomCrop
gt_patch_size: 256
scale: 4
keys: [image, image]
- name: PairedRandomHorizontalFlip
keys: [image, image]
- name: PairedRandomVerticalFlip
keys: [image, image]
- name: PairedRandomTransposeHW
keys: [image, image]
- name: TransposeSequence
keys: [image, image]
- name: NormalizeSequence
mean: [0., 0., 0.]
std: [255., 255., 255.]
keys: [image, image]
test:
name: SRREDSMultipleGTDataset
mode: test
name: VSRREDSMultipleGTDataset
lq_folder: data/REDS/REDS4_test_sharp_bicubic/X4
gt_folder: data/REDS/REDS4_test_sharp/X4
interval_list: [1]
random_reverse: False
number_frames: 100
use_flip: False
use_rot: False
scale: 4
val_partition: REDS4
num_workers: 0
batch_size: 1
num_clips: 270
ann_file: data/REDS/meta_info_REDS_GT.txt
num_frames: 100
test_mode: True
preprocess:
- name: GetNeighboringFramesIdx
interval_list: [1]
- name: ReadImageSequence
key: lq
- name: ReadImageSequence
key: gt
- name: Transforms
input_keys: [lq, gt]
pipeline:
- name: TransposeSequence
keys: [image, image]
- name: NormalizeSequence
mean: [0., 0., 0.]
std: [255., 255., 255.]
keys: [image, image]
lr_scheduler:
name: CosineAnnealingRestartLR
......
......@@ -20,47 +20,79 @@ model:
front_RBs: 5
back_RBs: 40
center: 2
predeblur: True #False
HR_in: True #False
predeblur: True
HR_in: True
w_TSA: True
pixel_criterion:
name: CharbonnierLoss
dataset:
train:
name: REDSDataset
mode: train
gt_folder: data/REDS/train_sharp/X4
lq_folder: data/REDS/train_blur/X4
img_format: png
crop_size: 256
interval_list: [1]
random_reverse: False
number_frames: 5
use_flip: True
use_rot: True
buf_size: 1024
scale: 1
fix_random_seed: 10
name: RepeatDataset
times: 1000
num_workers: 6
batch_size: 8
batch_size: 8 #4 gpus
dataset:
name: VSRREDSDataset
lq_folder: data/REDS/train_blur/X4
gt_folder: data/REDS/train_sharp/X4
ann_file: data/REDS/meta_info_REDS_GT.txt
num_frames: 5
val_partition: REDS4
test_mode: False
preprocess:
- name: GetFrameIdx
interval_list: [1]
frames_per_clip: 99
- name: ReadImageSequence
key: lq
- name: ReadImageSequence
key: gt
- name: Transforms
input_keys: [lq, gt]
pipeline:
- name: SRPairedRandomCrop
gt_patch_size: 256
scale: 1
keys: [image, image]
- name: PairedRandomHorizontalFlip
keys: [image, image]
- name: PairedRandomVerticalFlip
keys: [image, image]
- name: PairedRandomTransposeHW
keys: [image, image]
- name: TransposeSequence
keys: [image, image]
- name: NormalizeSequence
mean: [0., 0., 0.]
std: [255., 255., 255.]
keys: [image, image]
test:
name: REDSDataset
mode: test
gt_folder: data/REDS/REDS4_test_sharp/X4
name: VSRREDSDataset
lq_folder: data/REDS/REDS4_test_blur/X4
img_format: png
interval_list: [1]
random_reverse: False
number_frames: 5
batch_size: 1
use_flip: False
use_rot: False
buf_size: 1024
scale: 1
fix_random_seed: 10
gt_folder: data/REDS/REDS4_test_sharp/X4
ann_file: data/REDS/meta_info_REDS_GT.txt
num_frames: 5
val_partition: REDS4
test_mode: True
preprocess:
- name: GetFrameIdxwithPadding
padding: reflection_circle
- name: ReadImageSequence
key: lq
- name: ReadImageSequence
key: gt
- name: Transforms
input_keys: [lq, gt]
pipeline:
- name: TransposeSequence
keys: [image, image]
- name: NormalizeSequence
mean: [0., 0., 0.]
std: [255., 255., 255.]
keys: [image, image]
lr_scheduler:
name: CosineAnnealingRestartLR
......
......@@ -19,47 +19,79 @@ model:
front_RBs: 5
back_RBs: 40
center: 2
predeblur: True #False
HR_in: True #False
predeblur: True
HR_in: True
w_TSA: False
pixel_criterion:
name: CharbonnierLoss
dataset:
train:
name: REDSDataset
mode: train
gt_folder: data/REDS/train_sharp/X4
lq_folder: data/REDS/train_blur/X4
img_format: png
crop_size: 256
interval_list: [1]
random_reverse: False
number_frames: 5
use_flip: True
use_rot: True
buf_size: 1024
scale: 1
fix_random_seed: 10
name: RepeatDataset
times: 1000
num_workers: 6
batch_size: 8
batch_size: 8 #4 gpus
dataset:
name: VSRREDSDataset
lq_folder: data/REDS/train_blur/X4
gt_folder: data/REDS/train_sharp/X4
ann_file: data/REDS/meta_info_REDS_GT.txt
num_frames: 5
val_partition: REDS4
test_mode: False
preprocess:
- name: GetFrameIdx
interval_list: [1]
frames_per_clip: 99
- name: ReadImageSequence
key: lq
- name: ReadImageSequence
key: gt
- name: Transforms
input_keys: [lq, gt]
pipeline:
- name: SRPairedRandomCrop
gt_patch_size: 256
scale: 1
keys: [image, image]
- name: PairedRandomHorizontalFlip
keys: [image, image]
- name: PairedRandomVerticalFlip
keys: [image, image]
- name: PairedRandomTransposeHW
keys: [image, image]
- name: TransposeSequence
keys: [image, image]
- name: NormalizeSequence
mean: [0., 0., 0.]
std: [255., 255., 255.]
keys: [image, image]
test:
name: REDSDataset
mode: test
gt_folder: data/REDS/REDS4_test_sharp/X4
name: VSRREDSDataset
lq_folder: data/REDS/REDS4_test_blur/X4
img_format: png
interval_list: [1]
random_reverse: False
number_frames: 5
batch_size: 1
use_flip: False
use_rot: False
buf_size: 1024
scale: 1
fix_random_seed: 10
gt_folder: data/REDS/REDS4_test_sharp/X4
ann_file: data/REDS/meta_info_REDS_GT.txt
num_frames: 5
val_partition: REDS4
test_mode: True
preprocess:
- name: GetFrameIdxwithPadding
padding: reflection_circle
- name: ReadImageSequence
key: lq
- name: ReadImageSequence
key: gt
- name: Transforms
input_keys: [lq, gt]
pipeline:
- name: TransposeSequence
keys: [image, image]
- name: NormalizeSequence
mean: [0., 0., 0.]
std: [255., 255., 255.]
keys: [image, image]
lr_scheduler:
name: CosineAnnealingRestartLR
......
......@@ -28,39 +28,71 @@ model:
dataset:
train:
name: REDSDataset
mode: train
gt_folder: data/REDS/train_sharp/X4
lq_folder: data/REDS/train_sharp_bicubic/X4
img_format: png
crop_size: 256
interval_list: [1]
random_reverse: False
number_frames: 5
use_flip: True
use_rot: True
buf_size: 1024
scale: 4
fix_random_seed: 10
name: RepeatDataset
times: 1000
num_workers: 3
batch_size: 4 # 8GPUs
batch_size: 4 #8 gpus
dataset:
name: VSRREDSDataset
lq_folder: data/REDS/train_sharp_bicubic/X4
gt_folder: data/REDS/train_sharp/X4
ann_file: data/REDS/meta_info_REDS_GT.txt
num_frames: 5
val_partition: REDS4
test_mode: False
preprocess:
- name: GetFrameIdx
interval_list: [1]
frames_per_clip: 99
- name: ReadImageSequence
key: lq
- name: ReadImageSequence
key: gt
- name: Transforms
input_keys: [lq, gt]
pipeline:
- name: SRPairedRandomCrop
gt_patch_size: 256
scale: 4
keys: [image, image]
- name: PairedRandomHorizontalFlip
keys: [image, image]
- name: PairedRandomVerticalFlip
keys: [image, image]
- name: PairedRandomTransposeHW
keys: [image, image]
- name: TransposeSequence
keys: [image, image]
- name: NormalizeSequence
mean: [0., 0., 0.]
std: [255., 255., 255.]
keys: [image, image]
test:
name: REDSDataset
mode: test
gt_folder: data/REDS/REDS4_test_sharp/X4
name: VSRREDSDataset
lq_folder: data/REDS/REDS4_test_sharp_bicubic/X4
img_format: png
interval_list: [1]
random_reverse: False
number_frames: 5
batch_size: 1
use_flip: False
use_rot: False
buf_size: 1024
scale: 4
fix_random_seed: 10
gt_folder: data/REDS/REDS4_test_sharp/X4
ann_file: data/REDS/meta_info_REDS_GT.txt
num_frames: 5
val_partition: REDS4
test_mode: True
preprocess:
- name: GetFrameIdxwithPadding
padding: reflection_circle
- name: ReadImageSequence
key: lq
- name: ReadImageSequence
key: gt
- name: Transforms
input_keys: [lq, gt]
pipeline:
- name: TransposeSequence
keys: [image, image]
- name: NormalizeSequence
mean: [0., 0., 0.]
std: [255., 255., 255.]
keys: [image, image]
lr_scheduler:
name: CosineAnnealingRestartLR
......
......@@ -27,39 +27,71 @@ model:
dataset:
train:
name: REDSDataset
mode: train
gt_folder: data/REDS/train_sharp/X4
lq_folder: data/REDS/train_sharp_bicubic/X4
img_format: png
crop_size: 256
interval_list: [1]
random_reverse: False
number_frames: 5
use_flip: True
use_rot: True
buf_size: 1024
scale: 4
fix_random_seed: 10
name: RepeatDataset
times: 1000
num_workers: 3
batch_size: 4 # 8GPUs
batch_size: 4 #8 gpus
dataset:
name: VSRREDSDataset
lq_folder: data/REDS/train_sharp_bicubic/X4
gt_folder: data/REDS/train_sharp/X4
ann_file: data/REDS/meta_info_REDS_GT.txt
num_frames: 5
val_partition: REDS4
test_mode: False
preprocess:
- name: GetFrameIdx
interval_list: [1]
frames_per_clip: 99
- name: ReadImageSequence
key: lq
- name: ReadImageSequence
key: gt
- name: Transforms
input_keys: [lq, gt]
pipeline:
- name: SRPairedRandomCrop
gt_patch_size: 256
scale: 4
keys: [image, image]
- name: PairedRandomHorizontalFlip
keys: [image, image]
- name: PairedRandomVerticalFlip
keys: [image, image]
- name: PairedRandomTransposeHW
keys: [image, image]
- name: TransposeSequence
keys: [image, image]
- name: NormalizeSequence
mean: [0., 0., 0.]
std: [255., 255., 255.]
keys: [image, image]
test:
name: REDSDataset
mode: test
gt_folder: data/REDS/REDS4_test_sharp/X4
name: VSRREDSDataset
lq_folder: data/REDS/REDS4_test_sharp_bicubic/X4
img_format: png
interval_list: [1]
random_reverse: False
number_frames: 5
batch_size: 1
use_flip: False
use_rot: False
buf_size: 1024
scale: 4
fix_random_seed: 10
gt_folder: data/REDS/REDS4_test_sharp/X4
ann_file: data/REDS/meta_info_REDS_GT.txt
num_frames: 5
val_partition: REDS4
test_mode: True
preprocess:
- name: GetFrameIdxwithPadding
padding: reflection_circle
- name: ReadImageSequence
key: lq
- name: ReadImageSequence
key: gt
- name: Transforms
input_keys: [lq, gt]
pipeline:
- name: TransposeSequence
keys: [image, image]
- name: NormalizeSequence
mean: [0., 0., 0.]
std: [255., 255., 255.]
keys: [image, image]
lr_scheduler:
name: CosineAnnealingRestartLR
......
......@@ -31,39 +31,71 @@ export_model:
dataset:
train:
name: REDSDataset
mode: train
gt_folder: data/REDS/train_sharp/X4
lq_folder: data/REDS/train_sharp_bicubic/X4
img_format: png
crop_size: 256
interval_list: [1]
random_reverse: False
number_frames: 5
use_flip: True
use_rot: True
buf_size: 1024
scale: 4
fix_random_seed: 10
name: RepeatDataset
times: 1000
num_workers: 3
batch_size: 4 # 8GPUs
batch_size: 4 #8 gpus
dataset:
name: VSRREDSDataset
lq_folder: data/REDS/train_sharp_bicubic/X4
gt_folder: data/REDS/train_sharp/X4
ann_file: data/REDS/meta_info_REDS_GT.txt
num_frames: 5
val_partition: REDS4
test_mode: False
preprocess:
- name: GetFrameIdx
interval_list: [1]
frames_per_clip: 99
- name: ReadImageSequence
key: lq
- name: ReadImageSequence
key: gt
- name: Transforms
input_keys: [lq, gt]
pipeline:
- name: SRPairedRandomCrop
gt_patch_size: 256
scale: 4
keys: [image, image]
- name: PairedRandomHorizontalFlip
keys: [image, image]
- name: PairedRandomVerticalFlip
keys: [image, image]
- name: PairedRandomTransposeHW
keys: [image, image]
- name: TransposeSequence
keys: [image, image]
- name: NormalizeSequence
mean: [0., 0., 0.]
std: [255., 255., 255.]
keys: [image, image]
test:
name: REDSDataset
mode: test
gt_folder: data/REDS/REDS4_test_sharp/X4
name: VSRREDSDataset
lq_folder: data/REDS/REDS4_test_sharp_bicubic/X4
img_format: png
interval_list: [1]
random_reverse: False
number_frames: 5
batch_size: 1
use_flip: False
use_rot: False
buf_size: 1024
scale: 4
fix_random_seed: 10
gt_folder: data/REDS/REDS4_test_sharp/X4
ann_file: data/REDS/meta_info_REDS_GT.txt
num_frames: 5
val_partition: REDS4
test_mode: True
preprocess:
- name: GetFrameIdxwithPadding
padding: reflection_circle
- name: ReadImageSequence
key: lq
- name: ReadImageSequence
key: gt
- name: Transforms
input_keys: [lq, gt]
pipeline:
- name: TransposeSequence
keys: [image, image]
- name: NormalizeSequence
mean: [0., 0., 0.]
std: [255., 255., 255.]
keys: [image, image]
lr_scheduler:
name: CosineAnnealingRestartLR
......
......@@ -27,39 +27,71 @@ model:
dataset:
train:
name: REDSDataset
mode: train
gt_folder: data/REDS/train_sharp/X4
lq_folder: data/REDS/train_sharp_bicubic/X4
img_format: png
crop_size: 256
interval_list: [1]
random_reverse: False
number_frames: 5
use_flip: True
use_rot: True
buf_size: 1024
scale: 4
fix_random_seed: 10
name: RepeatDataset
times: 1000
num_workers: 3
batch_size: 4 # 8GPUs
batch_size: 4 #8 gpus
dataset:
name: VSRREDSDataset
lq_folder: data/REDS/train_sharp_bicubic/X4
gt_folder: data/REDS/train_sharp/X4
ann_file: data/REDS/meta_info_REDS_GT.txt
num_frames: 5
val_partition: REDS4
test_mode: False
preprocess:
- name: GetFrameIdx
interval_list: [1]
frames_per_clip: 99
- name: ReadImageSequence
key: lq
- name: ReadImageSequence
key: gt
- name: Transforms
input_keys: [lq, gt]
pipeline:
- name: SRPairedRandomCrop
gt_patch_size: 256
scale: 4
keys: [image, image]
- name: PairedRandomHorizontalFlip
keys: [image, image]
- name: PairedRandomVerticalFlip
keys: [image, image]
- name: PairedRandomTransposeHW
keys: [image, image]
- name: TransposeSequence
keys: [image, image]
- name: NormalizeSequence
mean: [0., 0., 0.]
std: [255., 255., 255.]
keys: [image, image]
test:
name: REDSDataset
mode: test
gt_folder: data/REDS/REDS4_test_sharp/X4
name: VSRREDSDataset
lq_folder: data/REDS/REDS4_test_sharp_bicubic/X4
img_format: png
interval_list: [1]
random_reverse: False
number_frames: 5
batch_size: 1
use_flip: False
use_rot: False
buf_size: 1024
scale: 4
fix_random_seed: 10
gt_folder: data/REDS/REDS4_test_sharp/X4
ann_file: data/REDS/meta_info_REDS_GT.txt
num_frames: 5
val_partition: REDS4
test_mode: True
preprocess:
- name: GetFrameIdxwithPadding
padding: reflection_circle
- name: ReadImageSequence
key: lq
- name: ReadImageSequence
key: gt
- name: Transforms
input_keys: [lq, gt]
pipeline:
- name: TransposeSequence
keys: [image, image]
- name: NormalizeSequence
mean: [0., 0., 0.]
std: [255., 255., 255.]
keys: [image, image]
lr_scheduler:
name: CosineAnnealingRestartLR
......
......@@ -44,7 +44,7 @@ dataset:
- name: Transpose
keys: [image, image]
- name: Normalize
mean: [0., .0, 0.]
mean: [0., 0., 0.]
std: [255., 255., 255.]
keys: [image, image]
test:
......@@ -63,7 +63,7 @@ dataset:
- name: Transpose
keys: [image, image]
- name: Normalize
mean: [0., .0, 0.]
mean: [0., 0., 0.]
std: [255., 255., 255.]
keys: [image, image]
......
......@@ -64,7 +64,7 @@ dataset:
- name: Transpose
keys: [image, image]
- name: Normalize
mean: [0., .0, 0.]
mean: [0., 0., 0.]
std: [255., 255., 255.]
keys: [image, image]
test:
......@@ -83,7 +83,7 @@ dataset:
- name: Transpose
keys: [image, image]
- name: Normalize
mean: [0., .0, 0.]
mean: [0., 0., 0.]
std: [255., 255., 255.]
keys: [image, image]
......
......@@ -24,35 +24,63 @@ dataset:
name: RepeatDataset
times: 1000
num_workers: 4
batch_size: 2 #4 GPUs
batch_size: 2 #4 gpus
dataset:
name: SRREDSMultipleGTDataset
mode: train
name: VSRREDSMultipleGTDataset
lq_folder: data/REDS/train_sharp_bicubic/X4
gt_folder: data/REDS/train_sharp/X4
crop_size: 256
interval_list: [1]
random_reverse: False
number_frames: 15
use_flip: True
use_rot: True
scale: 4
val_partition: REDS4
ann_file: data/REDS/meta_info_REDS_GT.txt
num_frames: 15
preprocess:
- name: GetNeighboringFramesIdx
interval_list: [1]
- name: ReadImageSequence
key: lq
- name: ReadImageSequence
key: gt
- name: Transforms
input_keys: [lq, gt]
pipeline:
- name: SRPairedRandomCrop
gt_patch_size: 256
scale: 4
keys: [image, image]
- name: PairedRandomHorizontalFlip
keys: [image, image]
- name: PairedRandomVerticalFlip
keys: [image, image]
- name: PairedRandomTransposeHW
keys: [image, image]
- name: TransposeSequence
keys: [image, image]
- name: NormalizeSequence
mean: [0., 0., 0.]
std: [255., 255., 255.]
keys: [image, image]
test:
name: SRREDSMultipleGTDataset
mode: test
name: VSRREDSMultipleGTDataset
lq_folder: data/REDS/REDS4_test_sharp_bicubic/X4
gt_folder: data/REDS/REDS4_test_sharp/X4
interval_list: [1]
random_reverse: False
number_frames: 100
use_flip: False
use_rot: False
scale: 4
val_partition: REDS4
num_workers: 0
batch_size: 1
ann_file: data/REDS/meta_info_REDS_GT.txt
num_frames: 100
test_mode: True
preprocess:
- name: GetNeighboringFramesIdx
interval_list: [1]
- name: ReadImageSequence
key: lq
- name: ReadImageSequence
key: gt
- name: Transforms
input_keys: [lq, gt]
pipeline:
- name: TransposeSequence
keys: [image, image]
- name: NormalizeSequence
mean: [0., 0., 0.]
std: [255., 255., 255.]
keys: [image, image]
lr_scheduler:
name: CosineAnnealingRestartLR
......
......@@ -40,7 +40,7 @@ dataset:
- name: Transpose
keys: [image, image]
- name: Normalize
mean: [0., .0, 0.]
mean: [0., 0., 0.]
std: [255., 255., 255.]
keys: [image, image]
test:
......@@ -59,7 +59,7 @@ dataset:
- name: Transpose
keys: [image, image]
- name: Normalize
mean: [0., .0, 0.]
mean: [0., 0., 0.]
std: [255., 255., 255.]
keys: [image, image]
......
......@@ -34,35 +34,62 @@ dataset:
times: 1000
num_workers: 4
batch_size: 2 #8 gpus
use_shared_memory: True
dataset:
name: SRREDSMultipleGTDataset
mode: train
name: VSRREDSMultipleGTDataset
lq_folder: data/REDS/train_sharp_bicubic/X4
gt_folder: data/REDS/train_sharp/X4
crop_size: 256
interval_list: [1]
random_reverse: False
number_frames: 30
use_flip: True
use_rot: True
scale: 4
val_partition: REDS4
ann_file: data/REDS/meta_info_REDS_GT.txt
num_frames: 30
preprocess:
- name: GetNeighboringFramesIdx
interval_list: [1]
- name: ReadImageSequence
key: lq
- name: ReadImageSequence
key: gt
- name: Transforms
input_keys: [lq, gt]
pipeline:
- name: SRPairedRandomCrop
gt_patch_size: 256
scale: 4
keys: [image, image]
- name: PairedRandomHorizontalFlip
keys: [image, image]
- name: PairedRandomVerticalFlip
keys: [image, image]
- name: PairedRandomTransposeHW
keys: [image, image]
- name: TransposeSequence
keys: [image, image]
- name: NormalizeSequence
mean: [0., 0., 0.]
std: [255., 255., 255.]
keys: [image, image]
test:
name: SRREDSMultipleGTDataset
mode: test
name: VSRREDSMultipleGTDataset
lq_folder: data/REDS/REDS4_test_sharp_bicubic/X4
gt_folder: data/REDS/REDS4_test_sharp/X4
interval_list: [1]
random_reverse: False
number_frames: 100
use_flip: False
use_rot: False
scale: 4
val_partition: REDS4
num_workers: 0
batch_size: 1
ann_file: data/REDS/meta_info_REDS_GT.txt
num_frames: 100
test_mode: True
preprocess:
- name: GetNeighboringFramesIdx
interval_list: [1]
- name: ReadImageSequence
key: lq
- name: ReadImageSequence
key: gt
- name: Transforms
input_keys: [lq, gt]
pipeline:
- name: TransposeSequence
keys: [image, image]
- name: NormalizeSequence
mean: [0., 0., 0.]
std: [255., 255., 255.]
keys: [image, image]
lr_scheduler:
name: CosineAnnealingRestartLR
......
......@@ -34,37 +34,62 @@ dataset:
times: 1000
num_workers: 6
batch_size: 2 #8 gpus
use_shared_memory: True
dataset:
name: SRREDSMultipleGTDataset
mode: train
name: VSRREDSMultipleGTDataset
lq_folder: data/REDS/train_sharp_bicubic/X4
gt_folder: data/REDS/train_sharp/X4
crop_size: 256
interval_list: [1]
random_reverse: False
number_frames: 20
use_flip: True
use_rot: True
scale: 4
val_partition: REDS4
num_clips: 270
ann_file: data/REDS/meta_info_REDS_GT.txt
num_frames: 20
preprocess:
- name: GetNeighboringFramesIdx
interval_list: [1]
- name: ReadImageSequence
key: lq
- name: ReadImageSequence
key: gt
- name: Transforms
input_keys: [lq, gt]
pipeline:
- name: SRPairedRandomCrop
gt_patch_size: 256
scale: 4
keys: [image, image]
- name: PairedRandomHorizontalFlip
keys: [image, image]
- name: PairedRandomVerticalFlip
keys: [image, image]
- name: PairedRandomTransposeHW
keys: [image, image]
- name: TransposeSequence
keys: [image, image]
- name: NormalizeSequence
mean: [0., 0., 0.]
std: [255., 255., 255.]
keys: [image, image]
test:
name: SRREDSMultipleGTDataset
mode: test
name: VSRREDSMultipleGTDataset
lq_folder: data/REDS/REDS4_test_sharp_bicubic/X4
gt_folder: data/REDS/REDS4_test_sharp/X4
interval_list: [1]
random_reverse: False
number_frames: 100
use_flip: False
use_rot: False
scale: 4
val_partition: REDS4
num_workers: 0
batch_size: 1
num_clips: 270
ann_file: data/REDS/meta_info_REDS_GT.txt
num_frames: 100
test_mode: True
preprocess:
- name: GetNeighboringFramesIdx
interval_list: [1]
- name: ReadImageSequence
key: lq
- name: ReadImageSequence
key: gt
- name: Transforms
input_keys: [lq, gt]
pipeline:
- name: TransposeSequence
keys: [image, image]
- name: NormalizeSequence
mean: [0., 0., 0.]
std: [255., 255., 255.]
keys: [image, image]
lr_scheduler:
name: CosineAnnealingRestartLR
......@@ -104,4 +129,4 @@ snapshot_config:
interval: 5000
export_model:
- {name: 'generator', inputs_num: 1}
\ No newline at end of file
- {name: 'generator', inputs_num: 1}
......@@ -36,7 +36,6 @@ dataset:
batch_size: 2 #8 gpus
dataset:
name: VSRVimeo90KDataset
# mode: train
lq_folder: data/vimeo90k/vimeo_septuplet_BD_matlabLRx4/sequences
gt_folder: data/vimeo90k/vimeo_septuplet/sequences
ann_file: data/vimeo90k/vimeo_septuplet/sep_trainlist.txt
......@@ -62,7 +61,7 @@ dataset:
keys: [image, image]
- name: MirrorVideoSequence
- name: NormalizeSequence
mean: [0., .0, 0.]
mean: [0., 0., 0.]
std: [255., 255., 255.]
keys: [image, image]
......@@ -89,7 +88,7 @@ dataset:
- name: TransposeSequence
keys: [image, image]
- name: NormalizeSequence
mean: [0., .0, 0.]
mean: [0., 0., 0.]
std: [255., 255., 255.]
keys: [image, image]
......
......@@ -61,7 +61,7 @@ dataset:
- name: Transpose
keys: [image, image]
- name: Normalize
mean: [0., .0, 0.]
mean: [0., 0., 0.]
std: [255., 255., 255.]
keys: [image, image]
- name: SRNoise
......@@ -84,7 +84,7 @@ dataset:
- name: Transpose
keys: [image, image]
- name: Normalize
mean: [0., .0, 0.]
mean: [0., 0., 0.]
std: [255., 255., 255.]
keys: [image, image]
......
......@@ -61,7 +61,7 @@ dataset:
- name: Transpose
keys: [image, image]
- name: Normalize
mean: [0., .0, 0.]
mean: [0., 0., 0.]
std: [255., 255., 255.]
keys: [image, image]
- name: SRNoise
......@@ -84,7 +84,7 @@ dataset:
- name: Transpose
keys: [image, image]
- name: Normalize
mean: [0., .0, 0.]
mean: [0., 0., 0.]
std: [255., 255., 255.]
keys: [image, image]
......
......@@ -21,11 +21,11 @@ from .common_vision_dataset import CommonVisionDataset
from .animeganv2_dataset import AnimeGANV2Dataset
from .wav2lip_dataset import Wav2LipDataset
from .starganv2_dataset import StarGANv2Dataset
from .edvr_dataset import REDSDataset
from .firstorder_dataset import FirstOrderDataset
from .lapstyle_dataset import LapStyleDataset
from .sr_reds_multiple_gt_dataset import SRREDSMultipleGTDataset
from .mpr_dataset import MPRTrain, MPRVal, MPRTest
from .vsr_reds_dataset import VSRREDSDataset
from .vsr_reds_multiple_gt_dataset import VSRREDSMultipleGTDataset
from .vsr_vimeo90k_dataset import VSRVimeo90KDataset
from .vsr_folder_dataset import VSRFolderDataset
from .photopen_dataset import PhotoPenDataset
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import random
import numpy as np
import scipy.io as scio
import cv2
import paddle
from paddle.io import Dataset, DataLoader
from .builder import DATASETS
logger = logging.getLogger(__name__)
@DATASETS.register()
class REDSDataset(Dataset):
"""
REDS dataset for EDVR model
"""
def __init__(self,
             mode,
             lq_folder,
             gt_folder,
             img_format="png",
             crop_size=256,
             interval_list=None,
             random_reverse=False,
             number_frames=5,
             batch_size=32,
             use_flip=False,
             use_rot=False,
             buf_size=1024,
             scale=4,
             fix_random_seed=False):
    """Store the sampling options and index the frames under ``lq_folder``.

    Args:
        mode: Run mode string. ``'train'`` skips the held-out clips while
            indexing, ``'test'`` sorts the index, and ``'infer'`` means no
            ground-truth root is recorded.
        lq_folder: Root directory of the input clips; stored as
            ``self.fileroot`` and listed one sub-directory per clip.
        gt_folder: Root directory of the ground-truth clips.
        img_format: Image file extension, stored as ``self.format``.
        crop_size: Edge length of the random training crop, measured on
            the ground-truth frame.
        interval_list: Candidate temporal strides between sampled frames.
            ``None`` (the default) means ``[1]``.
        random_reverse: Forwarded to the neighbour-frame sampler.
        number_frames: Number of input frames per sample.
        batch_size: Stored on the instance; not read by the methods
            visible in this class.
        use_flip: Enable the random horizontal flip for training samples.
        use_rot: Enable the random vertical flip / transpose for training
            samples.
        buf_size: Stored on the instance; not read by the methods visible
            in this class.
        scale: Ratio between ground-truth and input resolution. A value
            above 1 sets ``self.LR_input``.
        fix_random_seed: When truthy, ``random`` and ``numpy.random`` are
            both seeded with the constant 10 (the value itself is not
            used as the seed).
    """
    super(REDSDataset, self).__init__()
    self.format = img_format
    self.mode = mode
    self.crop_size = crop_size
    # A fresh list per instance: a literal ``[1]`` default in the signature
    # would be one shared object across every dataset created.
    self.interval_list = [1] if interval_list is None else interval_list
    self.random_reverse = random_reverse
    self.number_frames = number_frames
    self.batch_size = batch_size
    self.fileroot = lq_folder
    self.use_flip = use_flip
    self.use_rot = use_rot
    self.buf_size = buf_size
    self.fix_random_seed = fix_random_seed
    # Always define the attribute: __getitem__ reads self.gtroot regardless
    # of mode, so leaving it unset in 'infer' mode produced an
    # AttributeError instead of the intended "mode not implemented" error.
    self.gtroot = gt_folder if self.mode != 'infer' else None
    self.scale = scale
    self.LR_input = (self.scale > 1)
    if self.fix_random_seed:
        random.seed(10)
        np.random.seed(10)
        # Set together with the seeding; not read elsewhere in this class.
        self.num_reader_threads = 1
    self._init_()
def _init_(self):
    """Build ``self.filelist`` with one ``'<clip>_<frame>'`` key per file.

    Every sub-directory of ``self.fileroot`` is treated as a clip and every
    file inside it as a frame; the frame key is the file name up to its
    first ``'.'``. In ``'train'`` mode the clips ``000``, ``011``, ``015``
    and ``020`` are skipped because they are used for validation. In
    ``'test'`` mode the resulting list is sorted.
    """
    logger.info('initialize reader ... ')
    print("initialize reader")
    held_out_clips = ('000', '011', '015', '020')
    self.filelist = []
    for video_name in os.listdir(self.fileroot):
        # These four videos are used as val
        if self.mode == 'train' and video_name in held_out_clips:
            continue
        video_dir = os.path.join(self.fileroot, video_name)
        if not os.path.isdir(video_dir):
            # Stray regular files in the root (e.g. metadata or OS
            # artefacts) would make os.listdir raise NotADirectoryError.
            continue
        for frame_name in os.listdir(video_dir):
            frame_idx = frame_name.split('.')[0]
            # each item in self.filelist looks like '010_00000015',
            # '260_00000090'
            self.filelist.append(video_name + '_' + str(frame_idx))
    if self.mode == 'test':
        self.filelist.sort()
    print(len(self.filelist))
def __getitem__(self, index):
    """Load the sample stored at position ``index`` of ``self.filelist``.

    Returns:
        dict with three entries:
            ``'lq'``: the stacked input frames from ``get_sample_data``,
            ``'gt'``: the matching ground-truth frame,
            ``'lq_path'``: the ``'<clip>_<frame>'`` key of the sample.
    """
    sample_key = self.filelist[index]
    lq_frames, gt_frame = self.get_sample_data(sample_key,
                                               self.number_frames,
                                               self.interval_list,
                                               self.random_reverse,
                                               self.gtroot,
                                               self.fileroot,
                                               self.LR_input,
                                               self.crop_size,
                                               self.scale,
                                               self.use_flip,
                                               self.use_rot,
                                               self.mode)
    sample = {}
    sample['lq'] = lq_frames
    sample['gt'] = gt_frame
    sample['lq_path'] = sample_key
    return sample
def get_sample_data(self,
                    item,
                    number_frames,
                    interval_list,
                    random_reverse,
                    gtroot,
                    fileroot,
                    LR_input,
                    crop_size,
                    scale,
                    use_flip,
                    use_rot,
                    mode='train'):
    """Read, crop, augment and layout-convert one sample.

    Args:
        item: Sample key of the form ``'<clip>_<frame>'``.
        number_frames: Number of input frames to read.
        interval_list: Candidate temporal strides (train/valid only).
        random_reverse: Forwarded to ``get_neighbor_frames``.
        gtroot: Root directory of the ground-truth clips.
        fileroot: Root directory of the input clips.
        LR_input: True when the inputs are ``scale`` times smaller than
            the ground truth, so crop coordinates must be rescaled.
        crop_size: Crop edge length measured on the ground-truth frame.
        scale: Resolution ratio between ground truth and input.
        use_flip: Forwarded to ``img_augment`` as ``hflip``.
        use_rot: Forwarded to ``img_augment`` as ``rot``.
        mode: ``'train'``/``'valid'`` sample randomly, crop and augment;
            ``'test'`` reads a padded window around the frame unchanged.

    Returns:
        ``(img_LQs, img_GT)`` as float32 arrays. The input frames are
        stacked on a new leading axis and both results are moved from
        channel-last to channel-first with the channel order reversed.

    Raises:
        NotImplementedError: ``mode`` is not train, valid or test.
    """
    # Split on the LAST underscore only, so clip names that themselves
    # contain '_' are still parsed correctly.
    video_name, frame_name = item.rsplit('_', 1)
    is_training = mode in ('train', 'valid')

    if is_training:
        ngb_frames, name_b = self.get_neighbor_frames(
            frame_name,
            number_frames=number_frames,
            interval_list=interval_list,
            random_reverse=random_reverse)
    elif mode == 'test':
        ngb_frames, name_b = self.get_test_neighbor_frames(
            int(frame_name), number_frames)
    else:
        raise NotImplementedError('mode {} not implemented'.format(mode))

    # Honour the configured image format; the extension used to be
    # hard-coded to '.png', which silently ignored ``img_format``.
    ext = '.' + str(self.format).lstrip('.')

    # The sampler may have moved the centre frame, so the ground truth is
    # read using the name it returned, not the one encoded in ``item``.
    img_GT = self.read_img(os.path.join(gtroot, video_name, name_b + ext))
    frame_list = [
        self.read_img(
            os.path.join(fileroot, video_name, '{:08d}'.format(idx) + ext))
        for idx in ngb_frames
    ]

    H, W = frame_list[0].shape[:2]
    # add random crop
    if is_training:
        if LR_input:
            # Pick the crop on the low-resolution grid and scale the
            # offsets up so the ground-truth patch covers the same area.
            LQ_size = crop_size // scale
            rnd_h = random.randint(0, max(0, H - LQ_size))
            rnd_w = random.randint(0, max(0, W - LQ_size))
            frame_list = [
                v[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :]
                for v in frame_list
            ]
            rnd_h_HR, rnd_w_HR = int(rnd_h * scale), int(rnd_w * scale)
            img_GT = img_GT[rnd_h_HR:rnd_h_HR + crop_size,
                            rnd_w_HR:rnd_w_HR + crop_size, :]
        else:
            rnd_h = random.randint(0, max(0, H - crop_size))
            rnd_w = random.randint(0, max(0, W - crop_size))
            frame_list = [
                v[rnd_h:rnd_h + crop_size, rnd_w:rnd_w + crop_size, :]
                for v in frame_list
            ]
            img_GT = img_GT[rnd_h:rnd_h + crop_size,
                            rnd_w:rnd_w + crop_size, :]

    # add random flip and rotation; the ground truth rides along at the
    # end of the list so it receives exactly the same transform.
    frame_list.append(img_GT)
    if is_training:
        rlt = self.img_augment(frame_list, use_flip, use_rot)
    else:
        rlt = frame_list
    frame_list = rlt[0:-1]
    img_GT = rlt[-1]

    # stack LQ images to NHWC, N is the frame number
    img_LQs = np.stack(frame_list, axis=0)
    # BGR to RGB, HWC to CHW
    img_GT = img_GT[:, :, [2, 1, 0]]
    img_LQs = img_LQs[:, :, :, [2, 1, 0]]
    img_GT = np.transpose(img_GT, (2, 0, 1)).astype('float32')
    img_LQs = np.transpose(img_LQs, (0, 3, 1, 2)).astype('float32')
    return img_LQs, img_GT
def get_neighbor_frames(self,
                        frame_name,
                        number_frames,
                        interval_list,
                        random_reverse,
                        max_frame=99,
                        bordermode=False):
    """Choose the frame indices that make up one training sample.

    Args:
        frame_name: Frame index as a string (parsed with ``int``).
        number_frames: Number of indices to return.
        interval_list: Candidate strides; one is drawn at random.
        random_reverse: Allow the temporal order to be reversed at random.
        max_frame: Largest valid frame index of a clip.
        bordermode: If True, ``frame_name`` is the first frame and the
            window extends forwards or backwards from it. If False,
            ``frame_name`` is the centre of a symmetric window; a centre
            whose window would leave ``[0, max_frame]`` is replaced by
            randomly drawn centres until the window fits.

    Returns:
        ``(neighbor_list, name_b)``: the chosen indices and the zero-padded
        eight-digit name of the reference frame (first frame in border
        mode, middle frame otherwise).

    Raises:
        ValueError: in non-border mode, the window is wider than
            ``[0, max_frame]`` so no centre can ever fit.
        AssertionError: the window does not contain ``number_frames``
            indices (e.g. an even ``number_frames`` in non-border mode).
    """
    center_frame_idx = int(frame_name)
    half_N_frames = number_frames // 2
    interval = random.choice(interval_list)

    if bordermode:
        direction = 1
        if random_reverse and random.random() < 0.5:
            direction = random.choice([0, 1])
        # Running off either end of the clip overrides the random choice.
        span = interval * (number_frames - 1)
        if center_frame_idx + span > max_frame:
            direction = 0
        elif center_frame_idx - span < 0:
            direction = 1
        step = interval if direction == 1 else -interval
        neighbor_list = list(
            range(center_frame_idx, center_frame_idx + step * number_frames,
                  step))
        name_b = '{:08d}'.format(neighbor_list[0])
    else:
        reach = half_N_frames * interval
        if 2 * reach > max_frame:
            # Without this guard the resampling loop below can never find
            # a valid centre and spins forever.
            raise ValueError(
                'cannot fit {} frames with interval {} inside [0, {}]'.
                format(number_frames, interval, max_frame))
        # ensure not exceeding the borders
        while (center_frame_idx + reach > max_frame) or (
                center_frame_idx - reach < 0):
            center_frame_idx = random.randint(0, max_frame)
        neighbor_list = list(
            range(center_frame_idx - reach, center_frame_idx + reach + 1,
                  interval))
        if random_reverse and random.random() < 0.5:
            neighbor_list.reverse()
        name_b = '{:08d}'.format(neighbor_list[half_N_frames])

    assert len(neighbor_list) == number_frames, \
        "frames slected have length({}), but it should be ({})".format(len(neighbor_list), number_frames)
    return neighbor_list, name_b
def read_img(self, path, size=None):
    """Read an image with OpenCV and scale it by 1/255.

    Args:
        path: File path of the image.
        size: Accepted for interface compatibility; not used.

    Returns:
        float32 array with three dimensions and at most three channels.
        A two-dimensional image gains a trailing channel axis, and any
        channels beyond the third are dropped. Channel order is whatever
        ``cv2.imread`` produced.

    Raises:
        IOError: the file could not be read or decoded.
    """
    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    if img is None:
        # cv2.imread signals failure by returning None rather than raising;
        # fail here with the offending path instead of an AttributeError
        # on the next line.
        raise IOError('failed to read image: {}'.format(path))
    img = img.astype(np.float32) / 255.
    if img.ndim == 2:
        img = np.expand_dims(img, axis=2)
    # some images have 4 channels
    if img.shape[2] > 3:
        img = img[:, :, :3]
    return img
def img_augment(self, img_list, hflip=True, rot=True):
    """Apply one randomly drawn flip/transpose combination to every image.

    Three independent coin flips are made once per call and shared by all
    images in ``img_list``:
        * reverse the second axis (only considered when ``hflip`` is set),
        * reverse the first axis (only considered when ``rot`` is set),
        * swap the first two axes (only considered when ``rot`` is set).

    Returns:
        A new list with the transformed images, in the input order.
    """
    # A random number is consumed only for the options that are enabled,
    # in this fixed order.
    flip_second_axis = hflip and random.random() < 0.5
    flip_first_axis = rot and random.random() < 0.5
    swap_axes = rot and random.random() < 0.5

    augmented = []
    for frame in img_list:
        if flip_second_axis:
            frame = frame[:, ::-1, :]
        if flip_first_axis:
            frame = frame[::-1, :, :]
        if swap_axes:
            frame = frame.transpose(1, 0, 2)
        augmented.append(frame)
    return augmented
def get_test_neighbor_frames(self, crt_i, N, max_n=100, padding='new_info'):
    """Build the list of frame indices for a window centered on ``crt_i``.

    Positions of the window that fall outside ``[0, max_n - 1]`` are replaced
    according to ``padding``.

    Args:
        crt_i (int): Index of the center frame.
        N (int): Number of frames to read.
        max_n (int): Number of frames in the sequence (counted from 1).
        padding (str): One of replicate | reflection | new_info | circle.
            The mode is only checked when a position leaves the valid range.

    Example: crt_i = 0, N = 5
        replicate:  [0, 0, 0, 1, 2]
        reflection: [2, 1, 0, 1, 2]
        new_info:   [4, 3, 0, 1, 2]
        circle:     [3, 4, 0, 1, 2]

    Returns:
        tuple[list[int], str]: The index list and ``crt_i`` formatted as an
        8-digit zero-padded string.

    Raises:
        ValueError: If a position needs padding and ``padding`` is unknown.
    """
    last_idx = max_n - 1
    half = N // 2
    indices = []
    for pos in range(crt_i - half, crt_i + half + 1):
        before_start = pos < 0
        if not before_start and pos <= last_idx:
            # position lies inside the sequence, no padding required
            indices.append(pos)
            continue
        if padding == 'replicate':
            picked = 0 if before_start else last_idx
        elif padding == 'reflection':
            picked = -pos if before_start else last_idx * 2 - pos
        elif padding == 'new_info':
            if before_start:
                picked = (crt_i + half) + (-pos)
            else:
                picked = (crt_i - half) - (pos - last_idx)
        elif padding == 'circle':
            picked = N + pos if before_start else pos - N
        else:
            raise ValueError('Wrong padding mode')
        indices.append(picked)
    return indices, '{:08d}'.format(crt_i)
def __len__(self):
    """Report how many entries ``self.filelist`` holds."""
    entries = self.filelist
    return len(entries)
from .io import LoadImageFromFile, ReadImageSequence, GetNeighboringFramesIdx
from .io import LoadImageFromFile, ReadImageSequence, GetNeighboringFramesIdx, GetFrameIdx, GetFrameIdxwithPadding
from .transforms import (PairedRandomCrop, PairedRandomHorizontalFlip,
PairedRandomVerticalFlip, PairedRandomTransposeHW,
SRPairedRandomCrop, SplitPairedImage, SRNoise,
......
......@@ -179,3 +179,151 @@ class GetNeighboringFramesIdx:
datas['interval'] = interval
return datas
@PREPROCESS.register()
class GetFrameIdx:
    """Generate frame index for REDS datasets.

    Args:
        interval_list (list[int]): Interval list for temporal augmentation.
            It will randomly pick an interval from interval_list and sample
            frame index with the interval.
        frames_per_clip (int): Largest frame index of a clip; the sampled
            window is kept inside ``[0, frames_per_clip]``. Default: 99 for
            REDS dataset.
    """

    def __init__(self, interval_list, frames_per_clip=99):
        self.interval_list = interval_list
        self.frames_per_clip = frames_per_clip

    def __call__(self, results):
        """Call function.

        Reads ``results['key']`` (``'<clip>/<frame>'``),
        ``results['num_frames']``, ``results['lq_path']`` and
        ``results['gt_path']`` (root folders). The two path entries are
        replaced by lists of file paths and ``results['interval']`` is set.

        Args:
            results (dict): A dict containing the necessary information and
                data for augmentation.

        Returns:
            dict: A dict containing the processed data and information.

        Raises:
            ValueError: If the requested window cannot fit into
                ``[0, frames_per_clip]`` for the drawn interval.
        """
        clip_name, frame_name = results['key'].split('/')
        center_frame_idx = int(frame_name)
        num_half_frames = results['num_frames'] // 2

        interval = np.random.choice(self.interval_list)
        half_span = num_half_frames * interval
        # Without this guard the resampling loop below would never end:
        # a window wider than the clip fits around no center at all.
        if 2 * half_span > self.frames_per_clip:
            raise ValueError(
                f'Cannot fit {results["num_frames"]} frames with interval '
                f'{interval} into frame indices [0, {self.frames_per_clip}]')

        # ensure not exceeding the borders
        start_frame_idx = center_frame_idx - half_span
        end_frame_idx = center_frame_idx + half_span
        while (start_frame_idx < 0) or (end_frame_idx > self.frames_per_clip):
            # randint's upper bound is exclusive, hence the + 1
            center_frame_idx = np.random.randint(0, self.frames_per_clip + 1)
            start_frame_idx = center_frame_idx - half_span
            end_frame_idx = center_frame_idx + half_span
        frame_name = f'{center_frame_idx:08d}'
        neighbor_list = list(
            range(start_frame_idx, end_frame_idx + 1, interval))

        lq_path_root = results['lq_path']
        gt_path_root = results['gt_path']
        lq_path = [
            os.path.join(lq_path_root, clip_name, f'{v:08d}.png')
            for v in neighbor_list
        ]
        # only the center frame serves as ground truth
        gt_path = [os.path.join(gt_path_root, clip_name, f'{frame_name}.png')]
        results['lq_path'] = lq_path
        results['gt_path'] = gt_path
        results['interval'] = interval

        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += (f'(interval_list={self.interval_list}, '
                     f'frames_per_clip={self.frames_per_clip})')
        return repr_str
@PREPROCESS.register()
class GetFrameIdxwithPadding:
    """Build the frame window around the current frame, padding at borders.

    Intended for testing on the REDS and Vid4 datasets.

    Args:
        padding (str): padding mode, one of
            'replicate' | 'reflection' | 'reflection_circle' | 'circle'.

            Examples: current_idx = 0, num_frames = 5
            The generated frame indices under different padding mode:
            replicate: [0, 0, 0, 1, 2]
            reflection: [2, 1, 0, 1, 2]
            reflection_circle: [4, 3, 0, 1, 2]
            circle: [3, 4, 0, 1, 2]

        filename_tmpl (str): Template for file name. Default: '{:08d}'.
    """

    def __init__(self, padding, filename_tmpl='{:08d}'):
        supported = ('replicate', 'reflection', 'reflection_circle', 'circle')
        if padding not in supported:
            raise ValueError(f'Wrong padding mode {padding}.'
                             'Should be "replicate", "reflection", '
                             '"reflection_circle", "circle"')
        self.padding = padding
        self.filename_tmpl = filename_tmpl

    def __call__(self, results):
        """Replace the path roots in ``results`` with per-frame file paths.

        Reads ``results['key']`` (``'<clip>/<frame>'``),
        ``results['max_frame_num']``, ``results['num_frames']``,
        ``results['lq_path']`` and ``results['gt_path']``.

        Args:
            results (dict): A dict containing the necessary information and
                data for augmentation.

        Returns:
            dict: The same dict with ``lq_path`` and ``gt_path`` turned into
            lists of file paths.
        """
        clip_name, frame_name = results['key'].split('/')
        current_idx = int(frame_name)
        last_idx = results['max_frame_num'] - 1  # indices start from 0
        num_frames = results['num_frames']
        num_pad = num_frames // 2
        mode = self.padding

        frame_list = []
        for pos in range(current_idx - num_pad, current_idx + num_pad + 1):
            below = pos < 0
            if not below and pos <= last_idx:
                # inside the clip, nothing to pad
                pad_idx = pos
            elif mode == 'replicate':
                pad_idx = 0 if below else last_idx
            elif mode == 'reflection':
                pad_idx = -pos if below else last_idx * 2 - pos
            elif mode == 'reflection_circle':
                if below:
                    pad_idx = current_idx + num_pad - pos
                else:
                    pad_idx = (current_idx - num_pad) - (pos - last_idx)
            else:
                # every remaining mode is handled like 'circle'
                pad_idx = num_frames + pos if below else pos - num_frames
            frame_list.append(pad_idx)

        lq_path_root = results['lq_path']
        gt_path_root = results['gt_path']
        lq_paths = []
        for idx in frame_list:
            file_name = self.filename_tmpl.format(idx) + '.png'
            lq_paths.append(os.path.join(lq_path_root, clip_name, file_name))
        # ground truth is the current frame only
        gt_paths = [os.path.join(gt_path_root, clip_name, frame_name + '.png')]
        results['lq_path'] = lq_paths
        results['gt_path'] = gt_paths

        return results

    def __repr__(self):
        return f"{self.__class__.__name__}(padding='{self.padding}')"
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import random
import numpy as np
import cv2
from paddle.io import Dataset
from .builder import DATASETS
logger = logging.getLogger(__name__)
@DATASETS.register()
class SRREDSMultipleGTDataset(Dataset):
    """REDS dataset for video super resolution for recurrent networks.

    Each item is a randomly positioned run of LQ (Low-Quality) frames of one
    clip together with the GT (Ground-Truth) frames at the same indices.

    Args:
        mode (str): 'train' selects the clips outside the validation
            partition, any other value selects the validation clips. Random
            cropping and flip/transpose augmentation are applied only for
            'train' and 'valid'.
        lq_folder (str | :obj:`Path`): Path to a lq folder.
        gt_folder (str | :obj:`Path`): Path to a gt folder.
        crop_size (int): Edge length of the GT crop; the LQ crop uses
            ``crop_size // scale``. Default: 256.
        interval_list (list[int]): Candidate temporal strides, one is drawn
            per sample. Default: [1].
        random_reverse (bool): Randomly reverse the frame order.
            Default: False.
        number_frames (int): Number of frames per sample. Default: 15.
        use_flip (bool): Enable random horizontal flip. Default: False.
        use_rot (bool): Enable random vertical flip and H/W transpose.
            Default: False.
        scale (int): Upsampling scale ratio. Default: 4.
        val_partition (str): Validation partition mode. Choices ['official' or
            'REDS4']. Default: 'REDS4'.
        batch_size (int): Stored on the instance; not used by this class.
        num_clips (int): Number of clip ids (``000`` .. ``num_clips - 1``)
            that are enumerated. Default: 270.
    """

    # Number of frames assumed per clip (frame files 00000000 .. 00000099).
    _FRAMES_PER_CLIP = 100

    def __init__(self,
                 mode,
                 lq_folder,
                 gt_folder,
                 crop_size=256,
                 interval_list=[1],
                 random_reverse=False,
                 number_frames=15,
                 use_flip=False,
                 use_rot=False,
                 scale=4,
                 val_partition='REDS4',
                 batch_size=4,
                 num_clips=270):
        super(SRREDSMultipleGTDataset, self).__init__()
        self.mode = mode
        self.fileroot = str(lq_folder)
        self.gtroot = str(gt_folder)
        self.crop_size = crop_size
        self.interval_list = interval_list
        self.random_reverse = random_reverse
        self.number_frames = number_frames
        self.use_flip = use_flip
        self.use_rot = use_rot
        self.scale = scale
        self.val_partition = val_partition
        self.batch_size = batch_size
        self.num_clips = num_clips  # training num of LQ and GT pairs
        self.data_infos = self.load_annotations()

    def __getitem__(self, idx):
        """Get item at each call.

        Args:
            idx (int): Index for getting each item.

        Returns:
            dict: ``lq`` and ``gt`` frame stacks plus ``lq_path``, which
            holds the clip id of the sample.
        """
        clip = self.data_infos[idx]
        # Latest start that keeps the window inside the clip even for the
        # largest interval; equals 100 - number_frames for interval 1.
        widest = (self.number_frames - 1) * max(self.interval_list)
        idt = random.randint(0, self._FRAMES_PER_CLIP - 1 - widest)
        item = clip + '_' + f'{idt:03d}'
        img_LQs, img_GTs = self.get_sample_data(
            item, self.number_frames, self.interval_list, self.random_reverse,
            self.gtroot, self.fileroot, self.crop_size, self.scale,
            self.use_flip, self.use_rot, self.mode)
        return {'lq': img_LQs, 'gt': img_GTs, 'lq_path': clip}

    def load_annotations(self):
        """Build the list of clip ids for the current mode.

        Returns:
            list[str]: Three-digit clip ids.

        Raises:
            ValueError: If ``val_partition`` is neither 'REDS4' nor
                'official'.
        """
        # generate keys
        keys = [f'{i:03d}' for i in range(0, self.num_clips)]

        if self.val_partition == 'REDS4':
            val_partition = ['000', '011', '015', '020']
        elif self.val_partition == 'official':
            val_partition = [f'{i:03d}' for i in range(240, 270)]
        else:
            raise ValueError(f'Wrong validation partition {self.val_partition}.'
                             f'Supported ones are ["official", "REDS4"]')
        if self.mode == 'train':
            keys = [v for v in keys if v not in val_partition]
        else:
            keys = [v for v in keys if v in val_partition]

        return list(keys)

    def get_sample_data(self,
                        item,
                        number_frames,
                        interval_list,
                        random_reverse,
                        gtroot,
                        fileroot,
                        crop_size,
                        scale,
                        use_flip,
                        use_rot,
                        mode='train'):
        """Load, crop and augment the LQ/GT frames of one sample.

        Args:
            item (str): ``'<clip>_<start frame>'``.
            number_frames (int): Number of frames to load.
            interval_list (list[int]): Candidate temporal strides.
            random_reverse (bool): Randomly reverse the frame order.
            gtroot (str): Root folder of the GT clips.
            fileroot (str): Root folder of the LQ clips.
            crop_size (int): Edge length of the GT crop.
            scale (int): Ratio between GT and LQ resolution.
            use_flip (bool): Enable random horizontal flip.
            use_rot (bool): Enable random vertical flip and H/W transpose.
            mode (str): Crop and augmentation run for 'train' and 'valid'.

        Returns:
            tuple[numpy.ndarray, numpy.ndarray]: LQ and GT frames, each
            stacked along a new first axis with channels-first frames.
        """
        parts = item.split('_')
        video_name = parts[0]
        frame_name = parts[1]
        frame_idxs = self.get_neighbor_frames(frame_name,
                                              number_frames=number_frames,
                                              interval_list=interval_list,
                                              random_reverse=random_reverse)

        frame_list = []
        gt_list = []
        for frame_idx in frame_idxs:
            frame_idx_name = "%08d" % frame_idx
            img = self.read_img(
                os.path.join(fileroot, video_name, frame_idx_name + '.png'))
            frame_list.append(img)
            gt_img = self.read_img(
                os.path.join(gtroot, video_name, frame_idx_name + '.png'))
            gt_list.append(gt_img)
        H, W, C = frame_list[0].shape
        augmenting = (mode == 'train') or (mode == 'valid')
        # add random crop
        if augmenting:
            LQ_size = crop_size // scale
            rnd_h = random.randint(0, max(0, H - LQ_size))
            rnd_w = random.randint(0, max(0, W - LQ_size))
            frame_list = [
                v[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :]
                for v in frame_list
            ]
            # the GT crop covers the same region at the higher resolution
            rnd_h_HR, rnd_w_HR = int(rnd_h * scale), int(rnd_w * scale)
            gt_list = [
                v[rnd_h_HR:rnd_h_HR + crop_size,
                  rnd_w_HR:rnd_w_HR + crop_size, :] for v in gt_list
            ]

        # add random flip and rotation; LQ and GT go through img_augment
        # together so both receive the same transformation
        frame_list.extend(gt_list)
        if augmenting:
            rlt = self.img_augment(frame_list, use_flip, use_rot)
        else:
            rlt = frame_list
        frame_list = rlt[0:number_frames]
        gt_list = rlt[number_frames:]

        # convert every frame from HWC to CHW and stack to NCHW,
        # N is the frame number
        frame_list = [
            v.transpose(2, 0, 1).astype('float32') for v in frame_list
        ]
        gt_list = [v.transpose(2, 0, 1).astype('float32') for v in gt_list]

        img_LQs = np.stack(frame_list, axis=0)
        img_GTs = np.stack(gt_list, axis=0)

        return img_LQs, img_GTs

    def get_neighbor_frames(self, frame_name, number_frames, interval_list,
                            random_reverse):
        """Return ``number_frames`` indices starting at ``frame_name``.

        Args:
            frame_name (str | int): Index of the first frame.
            number_frames (int): Number of indices to return.
            interval_list (list[int]): Candidate strides, one is drawn at
                random.
            random_reverse (bool): Reverse the list with probability 0.5.

        Returns:
            list[int]: The selected frame indices.
        """
        frame_idx = int(frame_name)
        interval = random.choice(interval_list)
        # The stop value has to scale with the interval, otherwise strides
        # larger than 1 yield fewer than number_frames indices.
        neighbor_list = list(
            range(frame_idx, frame_idx + number_frames * interval, interval))
        if random_reverse and random.random() < 0.5:
            neighbor_list.reverse()

        assert len(neighbor_list) == number_frames, \
            "frames slected have length({}), but it should be ({})".format(len(neighbor_list), number_frames)

        return neighbor_list

    def read_img(self, path, size=None):
        """Read an image with OpenCV, scale by 1/255 and convert to RGB.

        Args:
            path (str): Path of the image file.
            size: Unused; kept so that existing callers keep working.

        Returns:
            numpy.ndarray: float32, HWC, converted with
            ``cv2.COLOR_BGR2RGB``, [0, 1] for 8-bit input.

        Raises:
            IOError: If ``cv2.imread`` could not load the file.
        """
        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        if img is None:
            # cv2.imread returns None instead of raising on failure
            raise IOError(f'Failed to read image: {path}')
        img = img.astype(np.float32) / 255.
        if img.ndim == 2:
            img = np.expand_dims(img, axis=2)
        # some images have 4 channels
        if img.shape[2] > 3:
            img = img[:, :, :3]
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    def img_augment(self, img_list, hflip=True, rot=True):
        """Apply one shared random flip/transpose combination to all images.

        Args:
            img_list (list): Images indexed as (H, W, C).
            hflip (bool): Allow a horizontal flip.
            rot (bool): Allow a vertical flip and a swap of H and W.

        Returns:
            list: The transformed images, in the order of ``img_list``.
        """
        hflip = hflip and random.random() < 0.5
        vflip = rot and random.random() < 0.5
        rot90 = rot and random.random() < 0.5

        def _augment(img):
            if hflip:
                img = img[:, ::-1, :]
            if vflip:
                img = img[::-1, :, :]
            if rot90:
                img = img.transpose(1, 0, 2)
            return img

        return [_augment(img) for img in img_list]

    def __len__(self):
        """Length of the dataset.

        Returns:
            int: Length of the dataset.
        """
        return len(self.data_infos)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from .builder import DATASETS
from .base_sr_dataset import BaseDataset
logger = logging.getLogger(__name__)
@DATASETS.register()
class VSRREDSDataset(BaseDataset):
    """REDS dataset for video super resolution with sliding-window networks.

    Every sample consists of several LQ (Low-Quality) frames and the center
    GT (Ground-Truth) frame; the configured preprocessing is handed to the
    base class.

    The REDS keys come from a txt file in which every line names one frame.
    Examples:

        000/00000000.png (720, 1280, 3)
        000/00000001.png (720, 1280, 3)

    Args:
        lq_folder (str): Path to a low quality image folder.
        gt_folder (str): Path to a ground truth image folder.
        ann_file (str): Path to the annotation file.
        num_frames (int): Window size for input frames; must be odd.
        preprocess (list[dict|callable]): A list functions of data transformations.
        val_partition (str): Validation partition mode. Choices ['official' or 'REDS4']. Default: 'REDS4'.
        test_mode (bool): Store `True` when building test dataset. Default: `False`.
    """

    def __init__(self,
                 lq_folder,
                 gt_folder,
                 ann_file,
                 num_frames,
                 preprocess,
                 val_partition='REDS4',
                 test_mode=False):
        super().__init__(preprocess)
        assert num_frames % 2 == 1, (f'num_frames should be odd numbers, '
                                     f'but received {num_frames }.')
        self.lq_folder, self.gt_folder = str(lq_folder), str(gt_folder)
        self.ann_file = str(ann_file)
        self.num_frames = num_frames
        self.val_partition = val_partition
        self.test_mode = test_mode
        self.data_infos = self.prepare_data_infos()

    def prepare_data_infos(self):
        """Load annotations for REDS dataset.

        Returns:
            list[dict]: One entry per frame key of the selected partition.
        """
        # a key is everything in front of the first '.', e.g. '000/00000000'
        with open(self.ann_file, 'r') as fin:
            frame_keys = [line.strip().split('.')[0] for line in fin]

        if self.val_partition == 'REDS4':
            val_clips = ['000', '011', '015', '020']
        elif self.val_partition == 'official':
            val_clips = [f'{v:03d}' for v in range(240, 270)]
        else:
            raise ValueError(f'Wrong validation partition {self.val_partition}.'
                             f'Supported ones are ["official", "REDS4"]')

        # test mode keeps the validation clips, otherwise they are dropped
        want_val = bool(self.test_mode)
        infos = []
        for frame_key in frame_keys:
            clip_id = frame_key.split('/')[0]
            if (clip_id in val_clips) != want_val:
                continue
            infos.append({
                'lq_path': self.lq_folder,
                'gt_path': self.gt_folder,
                'key': frame_key,
                'max_frame_num': 100,
                'num_frames': self.num_frames,
            })
        return infos
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from .builder import DATASETS
from .base_sr_dataset import BaseDataset
logger = logging.getLogger(__name__)
@DATASETS.register()
class VSRREDSMultipleGTDataset(BaseDataset):
    """REDS dataset for video super resolution for recurrent networks.

    The dataset loads several LQ (Low-Quality) frames and GT (Ground-Truth) frames.
    Then it applies specified transforms and finally returns a dict containing
    paired data and other information.

    Args:
        lq_folder (str): Path to a low quality image folder.
        gt_folder (str): Path to a ground truth image folder.
        ann_file (str): Path to the annotation file.
        num_frames (int): Window size for input frames.
        preprocess (list[dict|callable]): A list functions of data transformations.
        val_partition (str): Validation partition mode. Choices ['official' or 'REDS4'].
            Default: 'REDS4'.
        test_mode (bool): Store `True` when building test dataset. Default: `False`.
    """

    def __init__(self,
                 lq_folder,
                 gt_folder,
                 ann_file,
                 num_frames,
                 preprocess,
                 val_partition='REDS4',
                 test_mode=False):
        super().__init__(preprocess)
        self.lq_folder = str(lq_folder)
        self.gt_folder = str(gt_folder)
        self.ann_file = str(ann_file)
        self.num_frames = num_frames
        self.val_partition = val_partition
        self.test_mode = test_mode
        self.data_infos = self.prepare_data_infos()

    def prepare_data_infos(self):
        """Load annotations for REDS dataset.

        The annotation file lists frames as ``<clip>/<frame>...``; only the
        distinct clip names are kept, one entry per clip.

        Returns:
            list[dict]: One entry per clip of the selected partition, sorted
            by clip name.

        Raises:
            ValueError: If ``val_partition`` is neither 'REDS4' nor
                'official'.
        """
        # get keys
        with open(self.ann_file, 'r') as fin:
            keys = [v.strip().split('/')[0] for v in fin]
        # Sort the unique clip names: iteration order of a set of strings
        # depends on hash randomization, so list(set(...)) would give every
        # run and every worker process a different sample order.
        keys = sorted(set(keys))

        if self.val_partition == 'REDS4':
            val_partition = ['000', '011', '015', '020']
        elif self.val_partition == 'official':
            val_partition = [f'{v:03d}' for v in range(240, 270)]
        else:
            raise ValueError(f'Wrong validation partition {self.val_partition}.'
                             f'Supported ones are ["official", "REDS4"]')

        if self.test_mode:
            keys = [v for v in keys if v in val_partition]
        else:
            keys = [v for v in keys if v not in val_partition]

        data_infos = []
        for key in keys:
            data_infos.append(
                dict(lq_path=self.lq_folder,
                     gt_path=self.gt_folder,
                     key=key,
                     sequence_length=100,
                     num_frames=self.num_frames))

        return data_infos
......@@ -48,7 +48,7 @@ class EDVRModel(BaseSRModel):
self.visual_items['lq+1'] = self.lq[:, 3, :, :, :]
self.visual_items['lq+2'] = self.lq[:, 4, :, :, :]
if 'gt' in input:
self.gt = input['gt']
self.gt = input['gt'][:, 0, :, :, :]
self.visual_items['gt'] = self.gt
self.image_paths = input['lq_path']
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册