diff --git a/configs/cyclegan_cityscapes.yaml b/configs/cyclegan_cityscapes.yaml index b1a93786319b34e53f0336be4e836dd8417e4af8..e4e40208e876be1cf4ea88878eded500724d0343 100644 --- a/configs/cyclegan_cityscapes.yaml +++ b/configs/cyclegan_cityscapes.yaml @@ -1,8 +1,5 @@ epochs: 200 output_dir: output_dir -lambda_A: 10.0 -lambda_B: 10.0 -lambda_identity: 0.5 model: name: CycleGANModel @@ -20,60 +17,96 @@ model: n_layers: 3 norm_type: instance input_nc: 3 - gan_mode: lsgan + cycle_criterion: + name: L1Loss + idt_criterion: + name: L1Loss + loss_weight: 0.5 + gan_criterion: + name: GANLoss + gan_mode: lsgan dataset: train: name: UnpairedDataset - dataroot: data/cityscapes + dataroot_a: data/cityscapes/trainA + dataroot_b: data/cityscapes/trainB num_workers: 0 batch_size: 1 - phase: train - max_dataset_size: inf - direction: AtoB - input_nc: 3 - output_nc: 3 - serial_batches: False - pool_size: 50 - transforms: - - name: Resize - size: [286, 286] - interpolation: 'bicubic' #cv2.INTER_CUBIC - - name: RandomCrop - size: [256, 256] - - name: RandomHorizontalFlip - prob: 0.5 - - name: Transpose - - name: Normalize - mean: [127.5, 127.5, 127.5] - std: [127.5, 127.5, 127.5] + is_train: True + max_size: inf + preprocess: + - name: LoadImageFromFile + key: A + - name: LoadImageFromFile + key: B + - name: Transforms + input_keys: [A, B] + pipeline: + - name: Resize + size: [286, 286] + interpolation: 'bicubic' #cv2.INTER_CUBIC + keys: ['image', 'image'] + - name: RandomCrop + size: [256, 256] + keys: ['image', 'image'] + - name: RandomHorizontalFlip + prob: 0.5 + keys: ['image', 'image'] + - name: Transpose + keys: ['image', 'image'] + - name: Normalize + mean: [127.5, 127.5, 127.5] + std: [127.5, 127.5, 127.5] + keys: ['image', 'image'] test: - name: SingleDataset - dataroot: data/cityscapes/testB - max_dataset_size: inf - direction: BtoA - input_nc: 3 - output_nc: 3 - serial_batches: False - pool_size: 50 - transforms: - - name: Resize - size: [256, 256] - interpolation: 'bicubic' #cv2.INTER_CUBIC - - name: Transpose - - name: Normalize - mean: [127.5, 127.5, 127.5] - std: [127.5, 127.5, 127.5] - -optimizer: - name: Adam - beta1: 0.5 + name: UnpairedDataset + dataroot_a: data/cityscapes/testA + dataroot_b: data/cityscapes/testB + num_workers: 0 + batch_size: 1 + max_size: inf + is_train: False + load_pipeline: + - name: LoadImageFromFile + key: A + - name: LoadImageFromFile + key: B + - name: Transforms + input_keys: [A, B] + pipeline: + - name: Resize + size: [256, 256] + interpolation: 'bicubic' #cv2.INTER_CUBIC + keys: ['image', 'image'] + - name: Transpose + keys: ['image', 'image'] + - name: Normalize + mean: [127.5, 127.5, 127.5] + std: [127.5, 127.5, 127.5] + keys: ['image', 'image'] lr_scheduler: - name: linear + name: LinearDecay learning_rate: 0.0002 start_epoch: 100 decay_epochs: 100 + # will get from real dataset + iters_per_epoch: 1 + +optimizer: + optimG: + name: Adam + net_names: + - netG_A + - netG_B + beta1: 0.5 + optimD: + name: Adam + net_names: + - netD_A + - netD_B + beta1: 0.5 log_config: interval: 100 diff --git a/configs/cyclegan_horse2zebra.yaml b/configs/cyclegan_horse2zebra.yaml index e4b68166ca543f962976878bd2ed305ebcf9deb8..924c8e4b284377084eec3ea5d5c456c4cf093398 100644 --- a/configs/cyclegan_horse2zebra.yaml +++ b/configs/cyclegan_horse2zebra.yaml @@ -1,8 +1,5 @@ epochs: 200 output_dir: output_dir -lambda_A: 10.0 -lambda_B: 10.0 -lambda_identity: 0.5 model: name: CycleGANModel @@ -20,60 +17,96 @@ model: n_layers: 3 norm_type: instance input_nc: 3 - gan_mode: lsgan + 
cycle_criterion: + name: L1Loss + idt_criterion: + name: L1Loss + loss_weight: 0.5 + gan_criterion: + name: GANLoss + gan_mode: lsgan dataset: train: name: UnpairedDataset - dataroot: data/horse2zebra + dataroot_a: data/horse2zebra/trainA + dataroot_b: data/horse2zebra/trainB num_workers: 0 batch_size: 1 - phase: train - max_dataset_size: inf - direction: AtoB - input_nc: 3 - output_nc: 3 - serial_batches: False - pool_size: 50 - transforms: - - name: Resize - size: [286, 286] - interpolation: 'bicubic' #cv2.INTER_CUBIC - - name: RandomCrop - size: [256, 256] - - name: RandomHorizontalFlip - prob: 0.5 - - name: Transpose - - name: Normalize - mean: [127.5, 127.5, 127.5] - std: [127.5, 127.5, 127.5] + is_train: True + max_size: inf + load_pipeline: + - name: LoadImageFromFile + key: A + - name: LoadImageFromFile + key: B + - name: Transforms + input_keys: [A, B] + pipeline: + - name: Resize + size: [286, 286] + interpolation: 'bicubic' #cv2.INTER_CUBIC + keys: ['image', 'image'] + - name: RandomCrop + size: [256, 256] + keys: ['image', 'image'] + - name: RandomHorizontalFlip + prob: 0.5 + keys: ['image', 'image'] + - name: Transpose + keys: ['image', 'image'] + - name: Normalize + mean: [127.5, 127.5, 127.5] + std: [127.5, 127.5, 127.5] + keys: ['image', 'image'] test: - name: SingleDataset - dataroot: data/horse2zebra/testA - max_dataset_size: inf - direction: AtoB - input_nc: 3 - output_nc: 3 - serial_batches: False - pool_size: 50 - transforms: - - name: Resize - size: [256, 256] - interpolation: 'bicubic' #cv2.INTER_CUBIC - - name: Transpose - - name: Normalize - mean: [127.5, 127.5, 127.5] - std: [127.5, 127.5, 127.5] - -optimizer: - name: Adam - beta1: 0.5 + name: UnpairedDataset + dataroot_a: data/horse2zebra/testA + dataroot_b: data/horse2zebra/testB + num_workers: 0 + batch_size: 1 + max_size: inf + is_train: False + load_pipeline: + - name: LoadImageFromFile + key: A + - name: LoadImageFromFile + key: B + - name: Transforms + input_keys: [A, B] + pipeline: + - name: Resize + size: [256, 256] + interpolation: 'bicubic' #cv2.INTER_CUBIC + keys: ['image', 'image'] + - name: Transpose + keys: ['image', 'image'] + - name: Normalize + mean: [127.5, 127.5, 127.5] + std: [127.5, 127.5, 127.5] + keys: ['image', 'image'] lr_scheduler: - name: linear + name: LinearDecay learning_rate: 0.0002 start_epoch: 100 decay_epochs: 100 + # will get from real dataset + iters_per_epoch: 1 + +optimizer: + optimG: + name: Adam + net_names: + - netG_A + - netG_B + beta1: 0.5 + optimD: + name: Adam + net_names: + - netD_A + - netD_B + beta1: 0.5 log_config: interval: 100 diff --git a/configs/dcgan_mnist.yaml b/configs/dcgan_mnist.yaml index 89a2ad4fb4d0f4711ed9f0f3c01b8de1191cb375..a423dd03657b6fb157cbc9a1d6d11edee06aaa6a 100644 --- a/configs/dcgan_mnist.yaml +++ b/configs/dcgan_mnist.yaml @@ -15,52 +15,70 @@ model: norm_type: batch ndf: 64 input_nc: 1 - gan_mode: vanilla #wgangp + gan_criterion: + name: GANLoss + gan_mode: vanilla dataset: train: name: SingleDataset dataroot: data/mnist/train - phase: train - max_dataset_size: inf - direction: AtoB - input_nc: 1 - output_nc: 1 batch_size: 128 - serial_batches: False - transforms: - - name: Resize - size: [64, 64] - interpolation: 'bicubic' #cv2.INTER_CUBIC - - name: Transpose - - name: Normalize - mean: [127.5, 127.5, 127.5] - std: [127.5, 127.5, 127.5] + preprocess: + - name: LoadImageFromFile + key: A + - name: Transforms + input_keys: [A] + pipeline: + - name: Resize + size: [64, 64] + interpolation: 'bicubic' #cv2.INTER_CUBIC + keys: [image, image] + - 
name: Transpose + keys: [image, image] + - name: Normalize + mean: [127.5, 127.5, 127.5] + std: [127.5, 127.5, 127.5] + keys: [image, image] test: name: SingleDataset dataroot: data/mnist/test - max_dataset_size: inf - input_nc: 1 - output_nc: 1 - serial_batches: False - transforms: - - name: Resize - size: [64, 64] - interpolation: 'bicubic' #cv2.INTER_CUBIC - - name: Transpose - - name: Normalize - mean: [127.5, 127.5, 127.5] - std: [127.5, 127.5, 127.5] - -optimizer: - name: Adam - beta1: 0.5 + preprocess: + - name: LoadImageFromFile + key: A + - name: Transforms + input_keys: [A] + pipeline: + - name: Resize + size: [64, 64] + interpolation: 'bicubic' #cv2.INTER_CUBIC + keys: [image, image] + - name: Transpose + keys: [image, image] + - name: Normalize + mean: [127.5, 127.5, 127.5] + std: [127.5, 127.5, 127.5] + keys: [image, image] lr_scheduler: - name: linear - learning_rate: 0.00002 + name: LinearDecay + learning_rate: 0.0002 start_epoch: 100 decay_epochs: 100 + # will get from real dataset + iters_per_epoch: 1 + +optimizer: + optimizer_G: + name: Adam + net_names: + - netG + beta1: 0.5 + optimizer_D: + name: Adam + net_names: + - netD + beta1: 0.5 log_config: interval: 100 diff --git a/configs/esrgan_psnr_x4_div2k.yaml b/configs/esrgan_psnr_x4_div2k.yaml new file mode 100644 index 0000000000000000000000000000000000000000..642818e2ca6d9048289a08f4ce9b37d000f62e4d --- /dev/null +++ b/configs/esrgan_psnr_x4_div2k.yaml @@ -0,0 +1,105 @@ +total_iters: 1000000 +output_dir: output_dir +# tensor range for function tensor2img +min_max: + (0., 1.) + +model: + name: BaseSRModel + generator: + name: RRDBNet + in_nc: 3 + out_nc: 3 + nf: 64 + nb: 23 + pixel_criterion: + name: L1Loss + +dataset: + train: + name: SRDataset + gt_folder: data/DIV2K/DIV2K_train_HR_sub + lq_folder: data/DIV2K/DIV2K_train_LR_bicubic/X4_sub + num_workers: 4 + batch_size: 16 + scale: 4 + preprocess: + - name: LoadImageFromFile + key: lq + - name: LoadImageFromFile + key: gt + - name: Transforms + input_keys: [lq, gt] + pipeline: + - name: SRPairedRandomCrop + gt_patch_size: 128 + scale: 4 + keys: [image, image] + - name: PairedRandomHorizontalFlip + keys: [image, image] + - name: PairedRandomVerticalFlip + keys: [image, image] + - name: PairedRandomTransposeHW + keys: [image, image] + - name: Transpose + keys: [image, image] + - name: Normalize + mean: [0., .0, 0.] + std: [255., 255., 255.] + keys: [image, image] + test: + name: SRDataset + gt_folder: data/DIV2K/val_set14/Set14 + lq_folder: data/DIV2K/val_set14/Set14_bicLRx4 + scale: 4 + preprocess: + - name: LoadImageFromFile + key: lq + - name: LoadImageFromFile + key: gt + - name: Transforms + input_keys: [lq, gt] + pipeline: + - name: Transpose + keys: [image, image] + - name: Normalize + mean: [0., .0, 0.] + std: [255., 255., 255.] 
+ keys: [image, image] + +lr_scheduler: + name: CosineAnnealingRestartLR + learning_rate: 0.0002 + periods: [250000, 250000, 250000, 250000] + restart_weights: [1, 1, 1, 1] + eta_min: !!float 1e-7 + +optimizer: + name: Adam + # add parameters of net_name to optim + # name should in self.nets + net_names: + - generator + beta1: 0.9 + beta2: 0.99 + +validate: + interval: 5000 + save_img: false + + metrics: + psnr: # metric name, can be arbitrary + name: PSNR + crop_border: 4 + test_y_channel: false + ssim: + name: SSIM + crop_border: 4 + test_y_channel: false + +log_config: + interval: 10 + visiual_interval: 500 + +snapshot_config: + interval: 5000 diff --git a/configs/esrgan_x4_div2k.yaml b/configs/esrgan_x4_div2k.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a44d8b0ed573e5f03adf3775ba52fe67d3924cb0 --- /dev/null +++ b/configs/esrgan_x4_div2k.yaml @@ -0,0 +1,127 @@ +total_iters: 250000 +output_dir: output_dir +# tensor range for function tensor2img +min_max: + (0., 1.) + +model: + name: ESRGAN + generator: + name: RRDBNet + in_nc: 3 + out_nc: 3 + nf: 64 + nb: 23 + discriminator: + name: VGGDiscriminator128 + in_channels: 3 + num_feat: 64 + pixel_criterion: + name: L1Loss + loss_weight: !!float 1e-2 + perceptual_criterion: + name: PerceptualLoss + layer_weights: + '34': 1.0 + perceptual_weight: 1.0 + style_weight: 0.0 + norm_img: False + gan_criterion: + name: GANLoss + gan_mode: vanilla + loss_weight: !!float 5e-3 + +dataset: + train: + name: SRDataset + gt_folder: data/DIV2K/DIV2K_train_HR_sub + lq_folder: data/DIV2K/DIV2K_train_LR_bicubic/X4_sub + num_workers: 6 + batch_size: 32 + scale: 4 + preprocess: + - name: LoadImageFromFile + key: lq + - name: LoadImageFromFile + key: gt + - name: Transforms + input_keys: [lq, gt] + pipeline: + - name: SRPairedRandomCrop + gt_patch_size: 128 + scale: 4 + keys: [image, image] + - name: PairedRandomHorizontalFlip + keys: [image, image] + - name: PairedRandomVerticalFlip + keys: [image, image] + - name: PairedRandomTransposeHW + keys: [image, image] + - name: Transpose + keys: [image, image] + - name: Normalize + mean: [0., .0, 0.] + std: [255., 255., 255.] + keys: [image, image] + test: + name: SRDataset + gt_folder: data/DIV2K/val_set14/Set14 + lq_folder: data/DIV2K/val_set14/Set14_bicLRx4 + scale: 4 + preprocess: + - name: LoadImageFromFile + key: lq + - name: LoadImageFromFile + key: gt + - name: Transforms + input_keys: [lq, gt] + pipeline: + - name: Transpose + keys: [image, image] + - name: Normalize + mean: [0., .0, 0.] + std: [255., 255., 255.] 
+ keys: [image, image] + +lr_scheduler: + name: MultiStepDecay + learning_rate: 0.0001 + milestones: [50000, 100000, 200000, 300000] + gamma: 0.5 + +optimizer: + optimG: + name: Adam + net_names: + - generator + weight_decay: 0.0 + beta1: 0.9 + beta2: 0.99 + optimD: + name: Adam + net_names: + - discriminator + weight_decay: 0.0 + beta1: 0.9 + beta2: 0.99 + +validate: + interval: 5000 + save_img: false + + metrics: + psnr: # metric name, can be arbitrary + name: PSNR + crop_border: 4 + test_y_channel: false + ssim: + name: SSIM + crop_border: 4 + test_y_channel: false + +log_config: + interval: 100 + visiual_interval: 500 + +snapshot_config: + interval: 5000 diff --git a/configs/makeup.yaml b/configs/makeup.yaml index ee48c14f7e4963ab33ce4c016a66b968ff5bcbae..1d68f52ef3751ed55ef4f31e2d92b37eefaa7a90 100644 --- a/configs/makeup.yaml +++ b/configs/makeup.yaml @@ -1,9 +1,6 @@ epochs: 100 output_dir: tmp checkpoints_dir: checkpoints -lambda_A: 10.0 -lambda_B: 10.0 -lambda_identity: 0.5 model: name: MakeupModel @@ -17,7 +14,18 @@ model: n_layers: 3 input_nc: 3 norm_type: spectral - gan_mode: lsgan + cycle_criterion: + name: L1Loss + idt_criterion: + name: L1Loss + loss_weight: 0.5 + l1_criterion: + name: L1Loss + l2_criterion: + name: MSELoss + gan_criterion: + name: GANLoss + gan_mode: lsgan dataset: train: @@ -26,28 +34,42 @@ dataset: dataroot: data/MT-Dataset cls_list: [non-makeup, makeup] phase: train - pool_size: 16 test: name: MakeupDataset trans_size: 256 dataroot: data/MT-Dataset cls_list: [non-makeup, makeup] phase: test - pool_size: 16 -optimizer: - name: Adam - beta1: 0.5 lr_scheduler: - name: linear + name: LinearDecay learning_rate: 0.0002 start_epoch: 100 decay_epochs: 100 + # will get from real dataset + iters_per_epoch: 1 + +optimizer: + optimizer_G: + name: Adam + net_names: + - netG + beta1: 0.5 + optimizer_DA: + name: Adam + net_names: + - netD_A + beta1: 0.5 + optimizer_DB: + name: Adam + net_names: + - netD_B + beta1: 0.5 log_config: interval: 10 visiual_interval: 500 snapshot_config: - interval: 1 + interval: 5 diff --git a/configs/pix2pix_cityscapes.yaml b/configs/pix2pix_cityscapes.yaml index a7bfb121b16181ecf83f3d2b655c4024a6d99d39..0a3fc63ef810faf0e3f3ec6cae93257b400b9ca1 100644 --- a/configs/pix2pix_cityscapes.yaml +++ b/configs/pix2pix_cityscapes.yaml @@ -1,6 +1,5 @@ epochs: 200 output_dir: output_dir -lambda_L1: 100 model: name: Pix2PixModel @@ -18,70 +17,89 @@ model: n_layers: 3 input_nc: 6 norm_type: batch - gan_mode: vanilla + direction: b2a + pixel_criterion: + name: L1Loss + loss_weight: 100 + gan_criterion: + name: GANLoss + gan_mode: vanilla dataset: train: name: PairedDataset - dataroot: data/cityscapes + dataroot: data/cityscapes/train num_workers: 4 batch_size: 1 - phase: train - max_dataset_size: inf - direction: BtoA - input_nc: 3 - output_nc: 3 - serial_batches: False - pool_size: 0 - transforms: - - name: Resize - size: [286, 286] - interpolation: 'bicubic' #cv2.INTER_CUBIC - keys: [image, image] - - name: PairedRandomCrop - size: [256, 256] - keys: [image, image] - - name: PairedRandomHorizontalFlip - prob: 0.5 - keys: [image, image] - - name: Transpose - keys: [image, image] - - name: Normalize - mean: [127.5, 127.5, 127.5] - std: [127.5, 127.5, 127.5] - keys: [image, image] + preprocess: + - name: LoadImageFromFile + key: pair + - name: SplitPairedImage + key: pair + paired_keys: [A, B] + - name: Transforms + input_keys: [A, B] + pipeline: + - name: Resize + size: [286, 286] + interpolation: 'bicubic' #cv2.INTER_CUBIC + keys: [image, image] + - 
name: PairedRandomCrop + size: [256, 256] + keys: [image, image] + - name: PairedRandomHorizontalFlip + prob: 0.5 + keys: [image, image] + - name: Transpose + keys: [image, image] + - name: Normalize + mean: [127.5, 127.5, 127.5] + std: [127.5, 127.5, 127.5] + keys: [image, image] test: name: PairedDataset - dataroot: data/cityscapes/ - phase: test - max_dataset_size: inf - direction: BtoA - input_nc: 3 - output_nc: 3 - serial_batches: True - pool_size: 50 - transforms: - - name: Resize - size: [256, 256] - interpolation: 'bicubic' #cv2.INTER_CUBIC - keys: [image, image] - - name: Transpose - keys: [image, image] - - name: Normalize - mean: [127.5, 127.5, 127.5] - std: [127.5, 127.5, 127.5] - keys: [image, image] - - -optimizer: - name: Adam - beta1: 0.5 + dataroot: data/cityscapes/test + num_workers: 4 + batch_size: 1 + load_pipeline: + - name: LoadImageFromFile + key: pair + - name: SplitPairedImage + key: pair + paired_keys: [A, B] + - name: Transforms + input_keys: [A, B] + pipeline: + - name: Resize + size: [256, 256] + interpolation: 'bicubic' #cv2.INTER_CUBIC + keys: [image, image] + - name: Transpose + keys: [image, image] + - name: Normalize + mean: [127.5, 127.5, 127.5] + std: [127.5, 127.5, 127.5] + keys: [image, image] lr_scheduler: - name: linear + name: LinearDecay learning_rate: 0.0002 start_epoch: 100 decay_epochs: 100 + # will get from real dataset + iters_per_epoch: 1 + +optimizer: + optimG: + name: Adam + net_names: + - netG + beta1: 0.5 + optimD: + name: Adam + net_names: + - netD + beta1: 0.5 log_config: interval: 100 diff --git a/configs/pix2pix_cityscapes_2gpus.yaml b/configs/pix2pix_cityscapes_2gpus.yaml index 64484387251afc4f8e6c2e7b6309faf0e5a4ea8a..32bf846c3a3fc435830eee374a845c434e6f4e49 100644 --- a/configs/pix2pix_cityscapes_2gpus.yaml +++ b/configs/pix2pix_cityscapes_2gpus.yaml @@ -1,6 +1,5 @@ epochs: 200 output_dir: output_dir -lambda_L1: 100 model: name: Pix2PixModel @@ -18,69 +17,86 @@ model: n_layers: 3 input_nc: 6 norm_type: batch - gan_mode: vanilla + direction: b2a + pixel_criterion: + name: L1Loss + loss_weight: 100 + gan_criterion: + name: GANLoss + gan_mode: vanilla dataset: train: name: PairedDataset - dataroot: data/cityscapes - num_workers: 0 + dataroot: data/cityscapes/train + num_workers: 4 batch_size: 1 - phase: train - max_dataset_size: inf - direction: BtoA - input_nc: 3 - output_nc: 3 - serial_batches: False - pool_size: 0 - transforms: - - name: Resize - size: [286, 286] - interpolation: 'bicubic' #cv2.INTER_CUBIC - keys: [image, image] - - name: PairedRandomCrop - size: [256, 256] - keys: [image, image] - - name: PairedRandomHorizontalFlip - prob: 0.5 - keys: [image, image] - - name: Transpose - keys: [image, image] - - name: Normalize - mean: [127.5, 127.5, 127.5] - std: [127.5, 127.5, 127.5] - keys: [image, image] + preprocess: + - name: LoadImageFromFile + key: pair + - name: SplitPairedImage + key: pair + paired_keys: [A, B] + - name: Transforms + input_keys: [A, B] + pipeline: + - name: Resize + size: [286, 286] + interpolation: 'bicubic' #cv2.INTER_CUBIC + keys: [image, image] + - name: PairedRandomCrop + size: [256, 256] + keys: [image, image] + - name: PairedRandomHorizontalFlip + prob: 0.5 + keys: [image, image] + - name: Transpose + keys: [image, image] + - name: Normalize + mean: [127.5, 127.5, 127.5] + std: [127.5, 127.5, 127.5] + keys: [image, image] test: name: PairedDataset - dataroot: data/cityscapes/ - phase: test - max_dataset_size: inf - direction: BtoA - input_nc: 3 - output_nc: 3 - serial_batches: True - pool_size: 50 
- transforms: - - name: Resize - size: [256, 256] - interpolation: 'bicubic' #cv2.INTER_CUBIC - keys: [image, image] - - name: Transpose - keys: [image, image] - - name: Normalize - mean: [127.5, 127.5, 127.5] - std: [127.5, 127.5, 127.5] - keys: [image, image] - -optimizer: - name: Adam - beta1: 0.5 + dataroot: data/cityscapes/test + num_workers: 4 + batch_size: 1 + load_pipeline: + - name: LoadImageFromFile + key: pair + - name: SplitPairedImage + key: pair + paired_keys: [A, B] + - name: Transforms + input_keys: [A, B] + pipeline: + - name: Resize + size: [256, 256] + interpolation: 'bicubic' #cv2.INTER_CUBIC + keys: [image, image] + - name: Transpose + keys: [image, image] + - name: Normalize + mean: [127.5, 127.5, 127.5] + std: [127.5, 127.5, 127.5] + keys: [image, image] lr_scheduler: - name: linear + name: LinearDecay learning_rate: 0.0004 start_epoch: 100 decay_epochs: 100 + # will get from real dataset + iters_per_epoch: 1 + +optimizer: + optimG: + name: Adam + net_names: + - netG + beta1: 0.5 + optimD: + name: Adam + net_names: + - netD + beta1: 0.5 log_config: interval: 100 diff --git a/configs/pix2pix_facades.yaml b/configs/pix2pix_facades.yaml index 1db2c13d656f47c8aa978f43c0f8aebba891b123..0e044f9cd037850d1d52beaf1870029a80d86630 100644 --- a/configs/pix2pix_facades.yaml +++ b/configs/pix2pix_facades.yaml @@ -1,6 +1,5 @@ epochs: 200 output_dir: output_dir -lambda_L1: 100 model: name: Pix2PixModel @@ -18,69 +17,86 @@ model: n_layers: 3 input_nc: 6 norm_type: batch - gan_mode: vanilla + direction: b2a + pixel_criterion: + name: L1Loss + loss_weight: 100 + gan_criterion: + name: GANLoss + gan_mode: vanilla dataset: train: name: PairedDataset - dataroot: data/facades/ - num_workers: 0 + dataroot: data/facades/train + num_workers: 4 batch_size: 1 - phase: train - max_dataset_size: inf - direction: BtoA - input_nc: 3 - output_nc: 3 - serial_batches: False - pool_size: 0 - transforms: - - name: Resize - size: [286, 286] - interpolation: 'bicubic' #cv2.INTER_CUBIC - keys: [image, image] - - name: PairedRandomCrop - size: [256, 256] - keys: [image, image] - - name: PairedRandomHorizontalFlip - prob: 0.5 - keys: [image, image] - - name: Transpose - keys: [image, image] - - name: Normalize - mean: [127.5, 127.5, 127.5] - std: [127.5, 127.5, 127.5] - keys: [image, image] + preprocess: + - name: LoadImageFromFile + key: pair + - name: SplitPairedImage + key: pair + paired_keys: [A, B] + - name: Transforms + input_keys: [A, B] + pipeline: + - name: Resize + size: [286, 286] + interpolation: 'bicubic' #cv2.INTER_CUBIC + keys: [image, image] + - name: PairedRandomCrop + size: [256, 256] + keys: [image, image] + - name: PairedRandomHorizontalFlip + prob: 0.5 + keys: [image, image] + - name: Transpose + keys: [image, image] + - name: Normalize + mean: [127.5, 127.5, 127.5] + std: [127.5, 127.5, 127.5] + keys: [image, image] test: name: PairedDataset - dataroot: data/facades/ - phase: test - max_dataset_size: inf - direction: BtoA - input_nc: 3 - output_nc: 3 - serial_batches: True - pool_size: 50 - transforms: - - name: Resize - size: [256, 256] - interpolation: 'bicubic' #cv2.INTER_CUBIC - keys: [image, image] - - name: Transpose - keys: [image, image] - - name: Normalize - mean: [127.5, 127.5, 127.5] - std: [127.5, 127.5, 127.5] - keys: [image, image] - -optimizer: - name: Adam - beta1: 0.5 + dataroot: data/facades/test + num_workers: 4 + batch_size: 1 + load_pipeline: + - name: LoadImageFromFile + key: pair + - name: SplitPairedImage + key: pair + paired_keys: [A, B] + - name: Transforms + input_keys: [A, B] + pipeline: + - name: Resize + size: [256, 256] + interpolation: 'bicubic' #cv2.INTER_CUBIC + keys: 
[image, image] + - name: Transpose + keys: [image, image] + - name: Normalize + mean: [127.5, 127.5, 127.5] + std: [127.5, 127.5, 127.5] + keys: [image, image] lr_scheduler: - name: linear + name: LinearDecay learning_rate: 0.0002 start_epoch: 100 decay_epochs: 100 + # will get from real dataset + iters_per_epoch: 1 + +optimizer: + optimG: + name: Adam + net_names: + - netG + beta1: 0.5 + optimD: + name: Adam + net_names: + - netD + beta1: 0.5 log_config: interval: 100 diff --git a/configs/ugatit_selfie2anime_light.yaml b/configs/ugatit_selfie2anime_light.yaml index ad0de434dd78eae2613ccf5993fe1e3a034422a2..a9d21be739b047d16a1f240ae6c5763f97ee04d8 100644 --- a/configs/ugatit_selfie2anime_light.yaml +++ b/configs/ugatit_selfie2anime_light.yaml @@ -1,9 +1,5 @@ epochs: 300 output_dir: output_dir -adv_weight: 1.0 -cycle_weight: 10.0 -identity_weight: 10.0 -cam_weight: 1000.0 model: name: UGATITModel @@ -25,57 +21,102 @@ model: input_nc: 3 ndf: 64 n_layers: 5 + l1_criterion: + name: L1Loss + mse_criterion: + name: MSELoss + bce_criterion: + name: BCEWithLogitsLoss + adv_weight: 1.0 + cycle_weight: 10.0 + identity_weight: 10.0 + cam_weight: 1000.0 dataset: train: name: UnpairedDataset - dataroot: data/selfie2anime + dataroot_a: data/selfie2anime/trainA + dataroot_b: data/selfie2anime/trainB num_workers: 0 - phase: train - max_dataset_size: inf - direction: AtoB - input_nc: 3 - output_nc: 3 - serial_batches: False - transforms: - - name: Resize - size: [286, 286] - interpolation: 'bilinear' #'bicubic' #cv2.INTER_CUBIC - - name: RandomCrop - size: [256, 256] - - name: RandomHorizontalFlip - prob: 0.5 - - name: Transpose - - name: Normalize - mean: [127.5, 127.5, 127.5] - std: [127.5, 127.5, 127.5] + batch_size: 1 + is_train: True + max_size: inf + preprocess: + - name: LoadImageFromFile + key: A + - name: LoadImageFromFile + key: B + - name: Transforms + input_keys: [A, B] + pipeline: + - name: Resize + size: [286, 286] + interpolation: 'bicubic' #cv2.INTER_CUBIC + keys: ['image', 'image'] + - name: RandomCrop + size: [256, 256] + keys: ['image', 'image'] + - name: RandomHorizontalFlip + prob: 0.5 + keys: ['image', 'image'] + - name: Transpose + keys: ['image', 'image'] + - name: Normalize + mean: [127.5, 127.5, 127.5] + std: [127.5, 127.5, 127.5] + keys: ['image', 'image'] test: - name: SingleDataset - dataroot: data/selfie2anime/testA - max_dataset_size: inf - direction: AtoB - input_nc: 3 - output_nc: 3 - serial_batches: False - transforms: - - name: Resize - size: [256, 256] - interpolation: 'bilinear' #cv2.INTER_CUBIC - - name: Transpose - - name: Normalize - mean: [127.5, 127.5, 127.5] - std: [127.5, 127.5, 127.5] - -optimizer: - name: Adam - beta1: 0.5 - weight_decay: 0.0001 + name: UnpairedDataset + dataroot_a: data/selfie2anime/testA + dataroot_b: data/selfie2anime/testB + num_workers: 0 + batch_size: 1 + max_size: inf + is_train: False + preprocess: + - name: LoadImageFromFile + key: A + - name: LoadImageFromFile + key: B + - name: Transforms + input_keys: [A, B] + pipeline: + - name: Resize + size: [256, 256] + interpolation: 'bicubic' #cv2.INTER_CUBIC + keys: ['image', 'image'] + - name: Transpose + keys: ['image', 'image'] + - name: Normalize + mean: [127.5, 127.5, 127.5] + std: [127.5, 127.5, 127.5] + keys: ['image', 'image'] lr_scheduler: - name: linear + name: LinearDecay learning_rate: 0.0001 start_epoch: 150 decay_epochs: 150 + # will get from real dataset + iters_per_epoch: 1 + +optimizer: + optimG: + name: Adam + net_names: + - genA2B + - genB2A + weight_decay: 0.0001 + beta1: 
0.5 + optimD: + name: Adam + net_names: + - disGA + - disGB + - disLA + - disLB + weight_decay: 0.0001 + beta1: 0.5 log_config: interval: 10 diff --git a/ppgan/datasets/__init__.py b/ppgan/datasets/__init__.py index 5b9d9568837d04be1ace8b6516a3204f95f861f8..59d6d4c233d805673e0ca67324f94aea230a5576 100644 --- a/ppgan/datasets/__init__.py +++ b/ppgan/datasets/__init__.py @@ -15,7 +15,7 @@ from .unpaired_dataset import UnpairedDataset from .single_dataset import SingleDataset from .paired_dataset import PairedDataset -from .sr_image_dataset import SRImageDataset +from .base_sr_dataset import SRDataset from .makeup_dataset import MakeupDataset from .common_vision_dataset import CommonVisionDataset from .animeganv2_dataset import AnimeGANV2Dataset diff --git a/ppgan/datasets/base_dataset.py b/ppgan/datasets/base_dataset.py index 93e9577b900a74448f0af194fa2b5719a87707b3..8ea7b8b0063b5f396575dfe1a7fe222a8267ec6a 100644 --- a/ppgan/datasets/base_dataset.py +++ b/ppgan/datasets/base_dataset.py @@ -12,105 +12,124 @@ # See the License for the specific language governing permissions and # limitations under the License. -# code was heavily based on https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix -import random -import numpy as np +import os +from pathlib import Path +from abc import ABCMeta, abstractmethod from paddle.io import Dataset -from PIL import Image -import cv2 -import paddle.vision.transforms as transforms -from .transforms import transforms as T -from abc import ABC, abstractmethod +from .preprocess import build_preprocess IMG_EXTENSIONS = ('.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP') -class BaseDataset(Dataset, ABC): - """This class is an abstract base class (ABC) for datasets. + +def scandir(dir_path, suffix=None, recursive=False): + """Scan a directory to find the interested files. + + Args: + dir_path (str | obj:`Path`): Path of the directory. + suffix (str | tuple(str), optional): File suffix that we are + interested in. Default: None. + recursive (bool, optional): If set to True, recursively scan the + directory. Default: False. + + Returns: + A generator for all the interested files with relative paths. """ - def __init__(self, cfg): - """Initialize the class; save the options in the class + if isinstance(dir_path, (str, Path)): + dir_path = str(dir_path) + else: + raise TypeError('"dir_path" must be a string or Path object') - Args: - cfg (dict) -- stores all the experiment flags - """ - self.cfg = cfg - self.root = cfg.dataroot + if (suffix is not None) and not isinstance(suffix, (str, tuple)): + raise TypeError('"suffix" must be a string or tuple of strings') - @abstractmethod - def __len__(self): - """Return the total number of images in the dataset.""" - return 0 + root = dir_path + + def _scandir(dir_path, suffix, recursive): + for entry in os.scandir(dir_path): + if not entry.name.startswith('.') and entry.is_file(): + rel_path = os.path.relpath(entry.path, root) + if suffix is None: + yield rel_path + elif rel_path.endswith(suffix): + yield rel_path + else: + if recursive: + yield from _scandir(entry.path, + suffix=suffix, + recursive=recursive) + else: + continue + + return _scandir(dir_path, suffix=suffix, recursive=recursive) + + +class BaseDataset(Dataset, metaclass=ABCMeta): + """Base class for datasets. + + All datasets should subclass it. + All subclasses should overwrite: + + ``prepare_data_infos``, which loads annotation information and generates + image lists. 
+ + Args: + preprocess (list[dict]): A sequence of data preprocess config. + + """ + def __init__(self, preprocess=None): + super(BaseDataset, self).__init__() + + if preprocess: + self.preprocess = build_preprocess(preprocess) @abstractmethod - def __getitem__(self, index): - """Return a data point and its metadata information. + def prepare_data_infos(self): + """Abstract function for loading annotations. + + All subclasses should overwrite this function + and set self.data_infos in it. + data_infos should be a list of dict: + [{key_path: file_path}, {key_path: file_path}, {key_path: file_path}] + """ + self.data_infos = None + + @staticmethod + def scan_folder(path): + """Obtain sample path list (including sub-folders) from a given folder. - Parameters: - index - - a random integer for data indexing + Args: + path (str|pathlib.Path): Folder path. Returns: - a dictionary of data with their names. It ususally contains the data itself and its metadata information. + list[str]: sample list obtained from the given folder. """ - pass - - -def get_params(cfg, size): - w, h = size - new_h = h - new_w = w - if cfg.preprocess == 'resize_and_crop': - new_h = new_w = cfg.load_size - elif cfg.preprocess == 'scale_width_and_crop': - new_w = cfg.load_size - new_h = cfg.load_size * h // w - - x = random.randint(0, np.maximum(0, new_w - cfg.crop_size)) - y = random.randint(0, np.maximum(0, new_h - cfg.crop_size)) - - flip = random.random() > 0.5 - - return {'crop_pos': (x, y), 'flip': flip} - - -def get_transform(cfg, - params=None, - grayscale=False, - method=cv2.INTER_CUBIC, - convert=True): - transform_list = [] - if grayscale: - print('grayscale not support for now!!!') - pass - if 'resize' in cfg.preprocess: - osize = (cfg.load_size, cfg.load_size) - transform_list.append(transforms.Resize(osize, method)) - elif 'scale_width' in cfg.preprocess: - print('scale_width not support for now!!!') - pass - - if 'crop' in cfg.preprocess: - - if params is None: - transform_list.append(T.RandomCrop(cfg.crop_size)) + + if isinstance(path, (str, Path)): + path = str(path) else: - transform_list.append(T.Crop(params['crop_pos'], cfg.crop_size)) - - if cfg.preprocess == 'none': - print('preprocess not support for now!!!') - pass - - if not cfg.no_flip: - if params is None: - transform_list.append(transforms.RandomHorizontalFlip()) - elif params['flip']: - transform_list.append(transforms.RandomHorizontalFlip(1.0)) - - if convert: - transform_list += [transforms.Permute(to_rgb=True)] - if cfg.get('normalize', None): - transform_list += [ - transforms.Normalize(cfg.normalize.mean, cfg.normalize.std) - ] - - return transforms.Compose(transform_list) + raise TypeError("'path' must be a str or a Path object, " + f'but received {type(path)}.') + + samples = list(scandir(path, suffix=IMG_EXTENSIONS, recursive=True)) + samples = [os.path.join(path, v) for v in samples] + assert samples, '{} has no valid image file.'.format(path) + return samples + + def __getitem__(self, idx): + datas = self.data_infos[idx] + + if hasattr(self, 'preprocess') and self.preprocess: + datas = self.preprocess(datas) + + return datas + + def __len__(self): + """Length of the dataset. + + Returns: + int: Length of the dataset. 
+ """ + return len(self.data_infos) diff --git a/ppgan/datasets/base_sr_dataset.py b/ppgan/datasets/base_sr_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..306ad9ade1260f95158f848d3b34813ed8205040 --- /dev/null +++ b/ppgan/datasets/base_sr_dataset.py @@ -0,0 +1,61 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import copy + +from pathlib import Path +from .base_dataset import BaseDataset +from .builder import DATASETS + + +@DATASETS.register() +class SRDataset(BaseDataset): + """Base super resulotion dataset for image restoration.""" + def __init__(self, + lq_folder, + gt_folder, + preprocess, + scale, + filename_tmpl='{}'): + super(SRDataset, self).__init__(preprocess) + self.lq_folder = lq_folder + self.gt_folder = gt_folder + self.scale = scale + self.filename_tmpl = filename_tmpl + + self.prepare_data_infos() + + def prepare_data_infos(self): + """Load annoations for SR dataset. + + It loads the LQ and GT image path from folders. + + Returns: + dict: Returned dict for LQ and GT pairs. + """ + self.data_infos = [] + lq_paths = self.scan_folder(self.lq_folder) + gt_paths = self.scan_folder(self.gt_folder) + assert len(lq_paths) == len(gt_paths), ( + f'gt and lq datasets have different number of images: ' + f'{len(lq_paths)}, {len(gt_paths)}.') + for gt_path in gt_paths: + basename, ext = os.path.splitext(os.path.basename(gt_path)) + lq_path = os.path.join(self.lq_folder, + (f'{self.filename_tmpl.format(basename)}' + f'{ext}')) + assert lq_path in lq_paths, f'{lq_path} is not in lq_paths.' 
+ self.data_infos.append(dict(lq_path=lq_path, gt_path=gt_path)) + return self.data_infos diff --git a/ppgan/datasets/builder.py b/ppgan/datasets/builder.py index ac994ab0f69c04aad016a59134ed3d4a32b46d00..8fe6c7ee9de1fb66c03e223a8217a8cbc0ab565e 100644 --- a/ppgan/datasets/builder.py +++ b/ppgan/datasets/builder.py @@ -66,22 +66,37 @@ class DictDataset(paddle.io.Dataset): class DictDataLoader(): - def __init__(self, dataset, batch_size, is_train, num_workers=4): + def __init__(self, + dataset, + batch_size, + is_train, + num_workers=4, + distributed=True): self.dataset = DictDataset(dataset) place = paddle.CUDAPlace(ParallelEnv().dev_id) \ if ParallelEnv().nranks > 1 else paddle.CUDAPlace(0) - sampler = DistributedBatchSampler(self.dataset, - batch_size=batch_size, - shuffle=True if is_train else False, - drop_last=True if is_train else False) - - self.dataloader = paddle.io.DataLoader(self.dataset, - batch_sampler=sampler, - places=place, - num_workers=num_workers) + if distributed: + sampler = DistributedBatchSampler( + self.dataset, + batch_size=batch_size, + shuffle=True if is_train else False, + drop_last=True if is_train else False) + + self.dataloader = paddle.io.DataLoader(self.dataset, + batch_sampler=sampler, + places=place, + num_workers=num_workers) + else: + self.dataloader = paddle.io.DataLoader( + self.dataset, + batch_size=batch_size, + shuffle=True if is_train else False, + drop_last=True if is_train else False, + places=place, + num_workers=num_workers) self.batch_size = batch_size @@ -117,12 +132,20 @@ class DictDataLoader(): return current_items -def build_dataloader(cfg, is_train=True): - dataset = DATASETS.get(cfg.name)(cfg) +def build_dataloader(cfg, is_train=True, distributed=True): + cfg_ = cfg.copy() + + batch_size = cfg_.pop('batch_size', 1) + num_workers = cfg_.pop('num_workers', 0) + + name = cfg_.pop('name') - batch_size = cfg.get('batch_size', 1) - num_workers = cfg.get('num_workers', 0) + dataset = DATASETS.get(name)(**cfg_) - dataloader = DictDataLoader(dataset, batch_size, is_train, num_workers) + dataloader = DictDataLoader(dataset, + batch_size, + is_train, + num_workers, + distributed=distributed) return dataloader diff --git a/ppgan/datasets/makeup_dataset.py b/ppgan/datasets/makeup_dataset.py index 56cdd9197352dd2639bbedccc14efb1d707fd714..4236aa1eab2035e7af19440d66b1d2d5f359fbce 100644 --- a/ppgan/datasets/makeup_dataset.py +++ b/ppgan/datasets/makeup_dataset.py @@ -13,35 +13,38 @@ # limitations under the License. import cv2 -import os.path -from .base_dataset import BaseDataset, get_transform -from .transforms.makeup_transforms import get_makeup_transform -import paddle.vision.transforms as T -from PIL import Image import random +import os.path import numpy as np +from PIL import Image + +import paddle +import paddle.vision.transforms as T +from .base_dataset import BaseDataset from ..utils.preprocess import * from .builder import DATASETS @DATASETS.register() -class MakeupDataset(BaseDataset): - def __init__(self, cfg): - """Initialize this dataset class. +class MakeupDataset(paddle.io.Dataset): + def __init__(self, dataroot, phase, trans_size, cls_list): + """Initialize psgan dataset class. - Parameters: - opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions + Args: + dataroot (str): Directory of dataset. + phase (str): 'train' or 'test'. 
""" - BaseDataset.__init__(self, cfg) - self.image_path = cfg.dataroot - self.mode = cfg.phase - self.transform = get_makeup_transform(cfg) + self.image_path = dataroot + self.mode = phase + self.trans_size = trans_size + self.cls_list = cls_list + self.transform = self.build_makeup_transform() self.norm = T.Normalize([127.5, 127.5, 127.5], [127.5, 127.5, 127.5]) - self.transform_mask = get_makeup_transform(cfg, pic="mask") - self.trans_size = cfg.trans_size - self.cls_list = cfg.cls_list + self.transform_mask = self.build_makeup_transform("mask") + self.trans_size = trans_size + self.cls_A = self.cls_list[0] self.cls_B = self.cls_list[1] for cls in self.cls_list: @@ -72,6 +75,18 @@ class MakeupDataset(BaseDataset): getattr(self, cls + "_mask_filenames").append(splits[1]) getattr(self, cls + "_lmks_filenames").append(splits[2]) + def build_makeup_transform(self, pic="image"): + if pic == "image": + transform = T.Compose([ + T.Resize(size=self.trans_size), + T.Transpose(), + ]) + else: + transform = T.Resize(size=self.trans_size, + interpolation=cv2.INTER_NEAREST) + + return transform + def __getitem__(self, index): """Return MANet and MDNet needed params. diff --git a/ppgan/datasets/paired_dataset.py b/ppgan/datasets/paired_dataset.py index 2c6be37c56d9ca59ff6c5dbda3f0b05b47cb04a4..503d920276a1389d2c2dc8a4f525412b9bb19d2a 100644 --- a/ppgan/datasets/paired_dataset.py +++ b/ppgan/datasets/paired_dataset.py @@ -12,65 +12,35 @@ # See the License for the specific language governing permissions and # limitations under the License. -import cv2 -import paddle -import os.path -from .base_dataset import BaseDataset, get_params, get_transform -from .image_folder import make_dataset - from .builder import DATASETS -from .transforms.builder import build_transforms +from .base_dataset import BaseDataset @DATASETS.register() class PairedDataset(BaseDataset): """A dataset class for paired image dataset. """ - def __init__(self, cfg): + def __init__(self, dataroot, preprocess): """Initialize this dataset class. Args: - cfg (dict): configs of datasets. - """ - BaseDataset.__init__(self, cfg) - self.dir_AB = os.path.join(cfg.dataroot, - cfg.phase) # get the image directory - self.AB_paths = sorted(make_dataset( - self.dir_AB, cfg.max_dataset_size)) # get image paths - - self.input_nc = self.cfg.output_nc if self.cfg.direction == 'BtoA' else self.cfg.input_nc - self.output_nc = self.cfg.input_nc if self.cfg.direction == 'BtoA' else self.cfg.output_nc - self.transforms = build_transforms(cfg.transforms) - - def __getitem__(self, index): - """Return a data point and its metadata information. + dataroot (str): Directory of dataset. + preprocess (list[dict]): A sequence of data preprocess config. - Parameters: - index - - a random integer for data indexing - - Returns a dictionary that contains A, B, A_paths and B_paths - A (tensor) - - an image in the input domain - B (tensor) - - its corresponding image in the target domain - A_paths (str) - - image paths - B_paths (str) - - image paths (same as A_paths) """ - # read a image given a random integer index - AB_path = self.AB_paths[index] - AB = cv2.cvtColor(cv2.imread(AB_path), cv2.COLOR_BGR2RGB) - - # split AB image into A and B - h, w = AB.shape[:2] - # w, h = AB.size - w2 = int(w / 2) + super(PairedDataset, self).__init__(preprocess) + self.dataroot = dataroot + self.data_infos = self.prepare_data_infos() - A = AB[:h, :w2, :] - B = AB[:h, w2:, :] + def prepare_data_infos(self): + """Load paired image paths. 
- # apply the same transform to both A and B - A, B = self.transforms((A, B)) - - return {'A': A, 'B': B, 'A_paths': AB_path, 'B_paths': AB_path} + Returns: + list[dict]: List that contains paired image paths. + """ + data_infos = [] + pair_paths = sorted(self.scan_folder(self.dataroot)) + for pair_path in pair_paths: + data_infos.append(dict(pair_path=pair_path)) - def __len__(self): - """Return the total number of images in the dataset.""" - return len(self.AB_paths) + return data_infos diff --git a/ppgan/datasets/preprocess/__init__.py b/ppgan/datasets/preprocess/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0876163fde7d374aacdc9c900b565fb6ed1fd268 --- /dev/null +++ b/ppgan/datasets/preprocess/__init__.py @@ -0,0 +1,6 @@ +from .io import LoadImageFromFile +from .transforms import (PairedRandomCrop, PairedRandomHorizontalFlip, + PairedRandomVerticalFlip, PairedRandomTransposeHW, + SRPairedRandomCrop, SplitPairedImage) + +from .builder import build_preprocess diff --git a/ppgan/datasets/preprocess/builder.py b/ppgan/datasets/preprocess/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..e25147c8c6bfb0d6206aa93a2a905ee411183040 --- /dev/null +++ b/ppgan/datasets/preprocess/builder.py @@ -0,0 +1,64 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import traceback + +from ...utils.registry import Registry, build_from_config + +LOAD_PIPELINE = Registry("LOAD_PIPELINE") +TRANSFORMS = Registry("TRANSFORM") +PREPROCESS = Registry("PREPROCESS") + + +class Compose(object): + """ + Composes several preprocess functions together into a single dataset + transform. + + Args: + functions (list[callable]): List of functions to compose. + + Returns: + A callable Compose object; its __call__ applies each of the given + :attr:`functions` sequentially. + + """ + def __init__(self, functions): + self.functions = functions + + def __call__(self, datas): + + for func in self.functions: + try: + datas = func(datas) + except Exception as e: + stack_info = traceback.format_exc() + print("failed to perform function [{}] with error: " + "{} and stack:\n{}".format(func, e, str(stack_info))) + raise RuntimeError + return datas + + +def build_preprocess(cfg): + preprocess = [] + if not isinstance(cfg, (list, tuple)): + cfg = [cfg] + + for cfg_ in cfg: + process = build_from_config(cfg_, PREPROCESS) + preprocess.append(process) + + preprocess = Compose(preprocess) + return preprocess diff --git a/ppgan/datasets/preprocess/io.py b/ppgan/datasets/preprocess/io.py new file mode 100644 index 0000000000000000000000000000000000000000..5857d58796aa871bcdbcbcd80cfa0a93e19dc89d --- /dev/null +++ b/ppgan/datasets/preprocess/io.py @@ -0,0 +1,57 @@ +import cv2 + +from .builder import PREPROCESS + + +@PREPROCESS.register() +class LoadImageFromFile(object): + """Load image from file. 
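+ +    Reads the image at results[f'{key}_path'] with OpenCV and stores the +    decoded array back in results under the same key. 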
+ + Args: + key (str): Key in results to find the corresponding path. Default: 'image'. + flag (int): Loading flag for images. Default: -1. + to_rgb (bool): Whether to convert the image to RGB. Default: True. + backend (str): IO backend where images are stored. Default: None. + save_original_img (bool): If True, maintain a copy of the image in + `results` dict with name of `f'ori_{key}'`. Default: False. + kwargs (dict): Args for file client. + """ + def __init__(self, + key='image', + flag=-1, + to_rgb=True, + save_original_img=False, + backend=None, + **kwargs): + self.key = key + self.flag = flag + self.to_rgb = to_rgb + self.backend = backend + self.save_original_img = save_original_img + self.kwargs = kwargs + + def __call__(self, results): + """Call function. + + Args: + results (dict): A dict containing the necessary information and + data for augmentation. + + Returns: + dict: A dict containing the processed data and information. + """ + + filepath = str(results[f'{self.key}_path']) + #TODO: use file client to manage io backend + # such as opencv, pil, lmdb + img = cv2.imread(filepath, self.flag) + if self.to_rgb: + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + + results[self.key] = img + results[f'{self.key}_path'] = filepath + results[f'{self.key}_ori_shape'] = img.shape + if self.save_original_img: + results[f'ori_{self.key}'] = img.copy() + + return results diff --git a/ppgan/datasets/preprocess/transforms.py b/ppgan/datasets/preprocess/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..57daf44a43d0d7b9f1464949fcd1cdc6565844cc --- /dev/null +++ b/ppgan/datasets/preprocess/transforms.py @@ -0,0 +1,218 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import sys +import random +import numbers +import collections + +import paddle.vision.transforms as T +import paddle.vision.transforms.functional as F + +from .builder import TRANSFORMS, build_from_config +from .builder import PREPROCESS + +if sys.version_info < (3, 3): + Sequence = collections.Sequence + Iterable = collections.Iterable +else: + Sequence = collections.abc.Sequence + Iterable = collections.abc.Iterable + +TRANSFORMS.register(T.Resize) +TRANSFORMS.register(T.RandomCrop) +TRANSFORMS.register(T.RandomHorizontalFlip) +TRANSFORMS.register(T.RandomVerticalFlip) +TRANSFORMS.register(T.Normalize) +TRANSFORMS.register(T.Transpose) + + +@PREPROCESS.register() +class Transforms(): + def __init__(self, pipeline, input_keys): + self.input_keys = input_keys + self.transforms = [] + for transform_cfg in pipeline: + self.transforms.append(build_from_config(transform_cfg, TRANSFORMS)) + + def __call__(self, datas): + data = [] + for k in self.input_keys: + data.append(datas[k]) + data = tuple(data) + for transform in self.transforms: + data = transform(data) + + if hasattr(transform, 'params') and isinstance( + transform.params, dict): + datas.update(transform.params) + + for i, k in enumerate(self.input_keys): + datas[k] = data[i] + + return datas + + +@PREPROCESS.register() +class SplitPairedImage: + def __init__(self, key, paired_keys=['A', 'B']): + self.key = key + self.paired_keys = paired_keys + + def __call__(self, datas): + # split AB image into A and B + h, w = datas[self.key].shape[:2] + # w, h = AB.size + w2 = int(w / 2) + + a, b = self.paired_keys + datas[a] = datas[self.key][:h, :w2, :] + datas[b] = datas[self.key][:h, w2:, :] + + datas[a + '_path'] = datas[self.key + '_path'] + datas[b + '_path'] = datas[self.key + '_path'] + + return datas + + +@TRANSFORMS.register() +class PairedRandomCrop(T.RandomCrop): + def __init__(self, size, keys=None): + super().__init__(size, keys=keys) + + if isinstance(size, int): + self.size = (size, size) + else: + self.size = size + + def _get_params(self, inputs): + image = inputs[self.keys.index('image')] + params = {} + params['crop_params'] = self._get_param(image, self.size) + return params + + def _apply_image(self, img): + i, j, h, w = self.params['crop_params'] + return F.crop(img, i, j, h, w) + + +@TRANSFORMS.register() +class PairedRandomHorizontalFlip(T.RandomHorizontalFlip): + def __init__(self, prob=0.5, keys=None): + super().__init__(prob, keys=keys) + + def _get_params(self, inputs): + params = {} + params['flip'] = random.random() < self.prob + return params + + def _apply_image(self, image): + if self.params['flip']: + return F.hflip(image) + return image + + +@TRANSFORMS.register() +class PairedRandomVerticalFlip(T.RandomVerticalFlip): + def __init__(self, prob=0.5, keys=None): + super().__init__(prob, keys=keys) + + def _get_params(self, inputs): + params = {} + params['flip'] = random.random() < self.prob + return params + + def _apply_image(self, image): + if self.params['flip']: + return F.vflip(image) + return image + + +@TRANSFORMS.register() +class PairedRandomTransposeHW(T.BaseTransform): + """Randomly transpose images in H and W dimensions with a probability. + + (TransposeHW = horizontal flip + anti-clockwise rotation by 90 degrees) + When used with horizontal/vertical flips, it serves as a way of rotation + augmentation. + It also supports randomly transposing a list of images. + + Required keys are the keys in attributes "keys", added or modified keys are + "transpose" and the keys in attributes "keys". 
+ + Args: + prob (float): The probability to transpose the images. + keys (list[str]): The images to be transposed. + """ + def __init__(self, prob=0.5, keys=None): + self.keys = keys + self.prob = prob + + def _get_params(self, inputs): + params = {} + params['transpose'] = random.random() < self.prob + return params + + def _apply_image(self, image): + if self.params['transpose']: + image = image.transpose(1, 0, 2) + return image + + +@TRANSFORMS.register() +class SRPairedRandomCrop(T.BaseTransform): + """Super resolution random crop. + + It crops a pair of lq and gt images with corresponding locations. + It also supports accepting lq list and gt list. + Required keys are "scale", "lq", and "gt", + added or modified keys are "lq" and "gt". + + Args: + scale (int): model upscale factor. + gt_patch_size (int): cropped gt patch size. + """ + def __init__(self, scale, gt_patch_size, keys=None): + self.gt_patch_size = gt_patch_size + self.scale = scale + self.keys = keys + + def __call__(self, inputs): + """inputs must be (lq_img, gt_img)""" + scale = self.scale + lq_patch_size = self.gt_patch_size // scale + + lq = inputs[0] + gt = inputs[1] + + h_lq, w_lq, _ = lq.shape + h_gt, w_gt, _ = gt.shape + + if h_gt != h_lq * scale or w_gt != w_lq * scale: + raise ValueError('scale size not match') + if h_lq < lq_patch_size or w_lq < lq_patch_size: + raise ValueError('lq size error') + + # randomly choose top and left coordinates for lq patch + top = random.randint(0, h_lq - lq_patch_size) + left = random.randint(0, w_lq - lq_patch_size) + # crop lq patch + lq = lq[top:top + lq_patch_size, left:left + lq_patch_size, ...] + # crop corresponding gt patch + top_gt, left_gt = int(top * scale), int(left * scale) + gt = gt[top_gt:top_gt + self.gt_patch_size, + left_gt:left_gt + self.gt_patch_size, ...] + + outputs = (lq, gt) + return outputs diff --git a/ppgan/datasets/single_dataset.py b/ppgan/datasets/single_dataset.py index 246f769259cae4d98105dbfd4f00b271eb5f1bd9..ad67c440d93a4e8a50ff541b03471d6aafa724b1 100644 --- a/ppgan/datasets/single_dataset.py +++ b/ppgan/datasets/single_dataset.py @@ -12,54 +12,34 @@ # See the License for the specific language governing permissions and # limitations under the License. -import cv2 -import paddle -from .base_dataset import BaseDataset, get_transform -from .image_folder import make_dataset - +from .base_dataset import BaseDataset from .builder import DATASETS -from .transforms.builder import build_transforms @DATASETS.register() class SingleDataset(BaseDataset): """ """ - def __init__(self, cfg): - """Initialize this dataset class. + def __init__(self, dataroot, preprocess): + """Initialize single dataset class. Args: - cfg (dict) -- stores all the experiment flags + dataroot (str): Directory of dataset. + preprocess (list[dict]): A sequence of data preprocess config. """ - BaseDataset.__init__(self, cfg) - self.A_paths = sorted(make_dataset(cfg.dataroot, cfg.max_dataset_size)) - input_nc = self.cfg.output_nc if self.cfg.direction == 'BtoA' else self.cfg.input_nc - self.transform = build_transforms(self.cfg.transforms) - - def __getitem__(self, index): - """Return a data point and its metadata information. + super(SingleDataset, self).__init__(preprocess) + self.dataroot = dataroot + self.data_infos = self.prepare_data_infos() - Parameters: - index - - a random integer for data indexing + def prepare_data_infos(self): + """Prepare image paths from a folder. 
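+ Every image found under dataroot becomes one sample, stored under the 'A_path' key. 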
- Returns a dictionary that contains A and A_paths - A(tensor) - - an image in one domain - A_paths(str) - - the path of the image + Returns: + list[dict]: List that contains paired image paths. """ - A_path = self.A_paths[index] - A_img = cv2.cvtColor(cv2.imread(A_path), cv2.COLOR_BGR2RGB) - A = self.transform(A_img) - - return {'A': A, 'A_paths': A_path} - - def __len__(self): - """Return the total number of images in the dataset.""" - return len(self.A_paths) + data_infos = [] + paths = sorted(self.scan_folder(self.dataroot)) + for path in paths: + data_infos.append(dict(A_path=path)) - def get_path_by_indexs(self, indexs): - if isinstance(indexs, paddle.Tensor): - indexs = indexs.numpy() - current_paths = [] - for index in indexs: - current_paths.append(self.A_paths[index]) - return current_paths + return data_infos diff --git a/ppgan/datasets/sr_image_dataset.py b/ppgan/datasets/sr_image_dataset.py deleted file mode 100644 index b642ef1cacb98b94541b0febb3ce51df8370df7f..0000000000000000000000000000000000000000 --- a/ppgan/datasets/sr_image_dataset.py +++ /dev/null @@ -1,258 +0,0 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import cv2 -import random -import numpy as np -import paddle.vision.transforms as transform - -from pathlib import Path -from paddle.io import Dataset -from .builder import DATASETS - - -def scandir(dir_path, suffix=None, recursive=False): - """Scan a directory to find the interested files. - """ - if isinstance(dir_path, (str, Path)): - dir_path = str(dir_path) - else: - raise TypeError('"dir_path" must be a string or Path object') - - if (suffix is not None) and not isinstance(suffix, (str, tuple)): - raise TypeError('"suffix" must be a string or tuple of strings') - - root = dir_path - - def _scandir(dir_path, suffix, recursive): - for entry in os.scandir(dir_path): - if not entry.name.startswith('.') and entry.is_file(): - rel_path = os.path.relpath(entry.path, root) - if suffix is None: - yield rel_path - elif rel_path.endswith(suffix): - yield rel_path - else: - if recursive: - yield from _scandir(entry.path, - suffix=suffix, - recursive=recursive) - else: - continue - - return _scandir(dir_path, suffix=suffix, recursive=recursive) - - -def paired_paths_from_folder(folders, keys, filename_tmpl): - """Generate paired paths from folders. - """ - assert len(folders) == 2, ( - 'The len of folders should be 2 with [input_folder, gt_folder]. ' - f'But got {len(folders)}') - assert len(keys) == 2, ( - 'The len of keys should be 2 with [input_key, gt_key]. 
' - f'But got {len(keys)}') - input_folder, gt_folder = folders - input_key, gt_key = keys - - input_paths = list(scandir(input_folder)) - gt_paths = list(scandir(gt_folder)) - assert len(input_paths) == len(gt_paths), ( - f'{input_key} and {gt_key} datasets have different number of images: ' - f'{len(input_paths)}, {len(gt_paths)}.') - paths = [] - for gt_path in gt_paths: - basename, ext = os.path.splitext(os.path.basename(gt_path)) - input_name = f'{filename_tmpl.format(basename)}{ext}' - input_path = os.path.join(input_folder, input_name) - assert input_name in input_paths, (f'{input_name} is not in ' - f'{input_key}_paths.') - gt_path = os.path.join(gt_folder, gt_path) - paths.append( - dict([(f'{input_key}_path', input_path), - (f'{gt_key}_path', gt_path)])) - return paths - - -def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path): - """Paired random crop. - - It crops lists of lq and gt images with corresponding locations. - - Args: - img_gts (list[ndarray] | ndarray): GT images. Note that all images - should have the same shape. If the input is an ndarray, it will - be transformed to a list containing itself. - img_lqs (list[ndarray] | ndarray): LQ images. Note that all images - should have the same shape. If the input is an ndarray, it will - be transformed to a list containing itself. - gt_patch_size (int): GT patch size. - scale (int): Scale factor. - gt_path (str): Path to ground-truth. - - Returns: - list[ndarray] | ndarray: GT images and LQ images. If returned results - only have one element, just return ndarray. - """ - - if not isinstance(img_gts, list): - img_gts = [img_gts] - if not isinstance(img_lqs, list): - img_lqs = [img_lqs] - - h_lq, w_lq, _ = img_lqs[0].shape - h_gt, w_gt, _ = img_gts[0].shape - lq_patch_size = gt_patch_size // scale - - if h_gt != h_lq * scale or w_gt != w_lq * scale: - raise ValueError( - f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x ', - f'multiplication of LQ ({h_lq}, {w_lq}).') - if h_lq < lq_patch_size or w_lq < lq_patch_size: - raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size ' - f'({lq_patch_size}, {lq_patch_size}). ' - f'Please remove {gt_path}.') - - # randomly choose top and left coordinates for lq patch - top = random.randint(0, h_lq - lq_patch_size) - left = random.randint(0, w_lq - lq_patch_size) - - # crop lq patch - img_lqs = [ - v[top:top + lq_patch_size, left:left + lq_patch_size, ...] - for v in img_lqs - ] - - # crop corresponding gt patch - top_gt, left_gt = int(top * scale), int(left * scale) - img_gts = [ - v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] - for v in img_gts - ] - if len(img_gts) == 1: - img_gts = img_gts[0] - if len(img_lqs) == 1: - img_lqs = img_lqs[0] - return img_gts, img_lqs - - -def augment(imgs, hflip=True, rotation=True, flows=None): - """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees). 
- """ - hflip = hflip and random.random() < 0.5 - vflip = rotation and random.random() < 0.5 - rot90 = rotation and random.random() < 0.5 - - def _augment(img): - if hflip: - cv2.flip(img, 1, img) - if vflip: - cv2.flip(img, 0, img) - if rot90: - img = img.transpose(1, 0, 2) - return img - - def _augment_flow(flow): - if hflip: - cv2.flip(flow, 1, flow) - flow[:, :, 0] *= -1 - if vflip: - cv2.flip(flow, 0, flow) - flow[:, :, 1] *= -1 - if rot90: - flow = flow.transpose(1, 0, 2) - flow = flow[:, :, [1, 0]] - return flow - - if not isinstance(imgs, list): - imgs = [imgs] - imgs = [_augment(img) for img in imgs] - if len(imgs) == 1: - imgs = imgs[0] - - if flows is not None: - if not isinstance(flows, list): - flows = [flows] - flows = [_augment_flow(flow) for flow in flows] - if len(flows) == 1: - flows = flows[0] - return imgs, flows - else: - return imgs - - -@DATASETS.register() -class SRImageDataset(Dataset): - """Paired image dataset for image restoration.""" - def __init__(self, cfg): - super(SRImageDataset, self).__init__() - self.cfg = cfg - - self.file_client = None - self.io_backend_opt = cfg['io_backend'] - - self.gt_folder, self.lq_folder = cfg['dataroot_gt'], cfg['dataroot_lq'] - if 'filename_tmpl' in cfg: - self.filename_tmpl = cfg['filename_tmpl'] - else: - self.filename_tmpl = '{}' - - if self.io_backend_opt['type'] == 'lmdb': - #TODO: LielinJiang support lmdb to accelerate io - pass - elif 'meta_info_file' in self.cfg and self.cfg[ - 'meta_info_file'] is not None: - #TODO: LielinJiang support lmdb to accelerate io - pass - else: - self.paths = paired_paths_from_folder( - [self.lq_folder, self.gt_folder], ['lq', 'gt'], - self.filename_tmpl) - - def __getitem__(self, index): - scale = self.cfg['scale'] - - # Load gt and lq images. Dimension order: HWC; channel order: BGR; - # image range: [0, 1], float32. - gt_path = self.paths[index]['gt_path'] - lq_path = self.paths[index]['lq_path'] - - img_gt = cv2.imread(gt_path).astype(np.float32) / 255. - img_lq = cv2.imread(lq_path).astype(np.float32) / 255. - - # augmentation for training - if self.cfg['phase'] == 'train': - gt_size = self.cfg['gt_size'] - # random crop - img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, - gt_path) - # flip, rotation - img_gt, img_lq = augment([img_gt, img_lq], self.cfg['use_flip'], - self.cfg['use_rot']) - - # TODO: color space transform - # BGR to RGB, HWC to CHW, numpy to tensor - permute = transform.Permute() - img_gt = permute(img_gt) - img_lq = permute(img_lq) - return { - 'lq': img_lq, - 'gt': img_gt, - 'lq_path': lq_path, - 'gt_path': gt_path - } - - def __len__(self): - return len(self.paths) diff --git a/ppgan/datasets/transforms/builder.py b/ppgan/datasets/transforms/builder.py index 3e9f23bd52bbf7ca14dbcaf0dbd8635c8d2b1bb5..12b05a6c0524274e0711938c51e77ed855a056b2 100644 --- a/ppgan/datasets/transforms/builder.py +++ b/ppgan/datasets/transforms/builder.py @@ -24,14 +24,11 @@ class Compose(object): """ Composes several transforms together use for composing list of transforms together for a dataset transform. - Args: transforms (list): List of transforms to compose. - Returns: A compose object which is callable, __call__ for this Compose object will call each given :attr:`transforms` sequencely. 
- """ def __init__(self, transforms): self.transforms = transforms diff --git a/ppgan/datasets/unpaired_dataset.py b/ppgan/datasets/unpaired_dataset.py index 53ffbbe2b34813c53f769b2a452862b071577426..18f4767737957e950ca210e6b5af514e350f2389 100644 --- a/ppgan/datasets/unpaired_dataset.py +++ b/ppgan/datasets/unpaired_dataset.py @@ -12,80 +12,68 @@ # See the License for the specific language governing permissions and # limitations under the License. -import cv2 import random import os.path -from .base_dataset import BaseDataset, get_transform -from .image_folder import make_dataset +from .base_dataset import BaseDataset from .builder import DATASETS -from .transforms.builder import build_transforms @DATASETS.register() class UnpairedDataset(BaseDataset): """ """ - def __init__(self, cfg): - """Initialize this dataset class. + def __init__(self, dataroot_a, dataroot_b, max_size, is_train, preprocess): + """Initialize unpaired dataset class. Args: - cfg (dict) -- stores all the experiment flags - """ - BaseDataset.__init__(self, cfg) - self.dir_A = os.path.join(cfg.dataroot, cfg.phase + - 'A') # create a path '/path/to/data/trainA' - self.dir_B = os.path.join(cfg.dataroot, cfg.phase + - 'B') # create a path '/path/to/data/trainB' - - self.A_paths = sorted(make_dataset( - self.dir_A, - cfg.max_dataset_size)) # load images from '/path/to/data/trainA' - self.B_paths = sorted(make_dataset( - self.dir_B, - cfg.max_dataset_size)) # load images from '/path/to/data/trainB' - self.A_size = len(self.A_paths) # get the size of dataset A - self.B_size = len(self.B_paths) # get the size of dataset B - btoA = self.cfg.direction == 'BtoA' - input_nc = self.cfg.output_nc if btoA else self.cfg.input_nc # get the number of channels of input image - output_nc = self.cfg.input_nc if btoA else self.cfg.output_nc # get the number of channels of output image - - self.transform_A = build_transforms(self.cfg.transforms) - self.transform_B = build_transforms(self.cfg.transforms) - - self.reset_paths() - - def reset_paths(self): - self.path_dict = {} + dataroot_a (str): Directory of dataset a. + dataroot_b (str): Directory of dataset b. + max_size (int): max size of dataset size. + is_train (int): whether in train mode. + preprocess (list[dict]): A sequence of data preprocess config. - def __getitem__(self, index): - """Return a data point and its metadata information. - - Parameters: - index (int) -- a random integer for data indexing - - Returns a dictionary that contains A, B, A_paths and B_paths - A (tensor) -- an image in the input domain - B (tensor) -- its corresponding image in the target domain - A_paths (str) -- image paths - B_paths (str) -- image paths """ - A_path = self.A_paths[ - index % self.A_size] # make sure index is within then range - if self.cfg.serial_batches: # make sure index is within then range - index_B = index % self.B_size - else: # randomize the index for domain B to avoid fixed pairs. - index_B = random.randint(0, self.B_size - 1) - B_path = self.B_paths[index_B] + super(UnpairedDataset, self).__init__(preprocess) + self.dir_A = os.path.join(dataroot_a) + self.dir_B = os.path.join(dataroot_b) + self.is_train = is_train + self.data_infos_a = self.prepare_data_infos(self.dir_A) + self.data_infos_b = self.prepare_data_infos(self.dir_B) + self.size_a = len(self.data_infos_a) + self.size_b = len(self.data_infos_b) + + def prepare_data_infos(self, dataroot): + """Load unpaired image paths of one domain. 
-        A_img = cv2.cvtColor(cv2.imread(A_path), cv2.COLOR_BGR2RGB)
-        B_img = cv2.cvtColor(cv2.imread(B_path), cv2.COLOR_BGR2RGB)
-        # apply image transformation
-        A = self.transform_A(A_img)
-        B = self.transform_B(B_img)
+        Args:
+            dataroot (str): Path to the folder root for unpaired images of
+                one domain.
 
-        # return A, B
-        return {'A': A, 'B': B, 'A_paths': A_path, 'B_paths': B_path}
+        Returns:
+            list[dict]: List that contains unpaired image paths of one domain.
+        """
+        data_infos = []
+        paths = sorted(self.scan_folder(dataroot))
+        for path in paths:
+            data_infos.append(dict(path=path))
+        return data_infos
+
+    def __getitem__(self, idx):
+        if self.is_train:
+            img_a_path = self.data_infos_a[idx % self.size_a]['path']
+            idx_b = random.randint(0, self.size_b - 1)
+            img_b_path = self.data_infos_b[idx_b]['path']
+            datas = dict(A_path=img_a_path, B_path=img_b_path)
+        else:
+            img_a_path = self.data_infos_a[idx % self.size_a]['path']
+            img_b_path = self.data_infos_b[idx % self.size_b]['path']
+            datas = dict(A_path=img_a_path, B_path=img_b_path)
+
+        if self.preprocess:
+            datas = self.preprocess(datas)
+
+        return datas
 
     def __len__(self):
         """Return the total number of images in the dataset.
@@ -93,4 +81,4 @@ class UnpairedDataset(BaseDataset):
         As we have two datasets with potentially different number of images,
         we take a maximum of
         """
-        return max(self.A_size, self.B_size)
+        return max(self.size_a, self.size_b)
diff --git a/ppgan/engine/trainer.py b/ppgan/engine/trainer.py
index 7c8a472bf3a32bb8a34b7a0a6d913886750db464..7a1e1a81a9c676de498092162805973f7b22dd09 100644
--- a/ppgan/engine/trainer.py
+++ b/ppgan/engine/trainer.py
@@ -27,25 +27,78 @@ from ..models.builder import build_model
 from ..utils.visual import tensor2img, save_image
 from ..utils.filesystem import makedirs, save, load
 from ..utils.timer import TimeAverager
-from ..metric.psnr_ssim import calculate_psnr, calculate_ssim
 
-class Trainer:
-    def __init__(self, cfg):
+class IterLoader:
+    def __init__(self, dataloader):
+        self._dataloader = dataloader
+        self.iter_loader = iter(self._dataloader)
+        self._epoch = 1
 
-        # build train dataloader
-        self.train_dataloader = build_dataloader(cfg.dataset.train)
+    @property
+    def epoch(self):
+        return self._epoch
+
+    def __next__(self):
+        try:
+            data = next(self.iter_loader)
+        except StopIteration:
+            self._epoch += 1
+            self.iter_loader = iter(self._dataloader)
+            data = next(self.iter_loader)
 
-        if 'lr_scheduler' in cfg.optimizer:
-            cfg.optimizer.lr_scheduler.step_per_epoch = len(
-                self.train_dataloader)
+        return data
+
+    def __len__(self):
+        return len(self._dataloader)
+
+
+class Trainer:
+    """
+    # trainer calling logic:
+    #
+    #                build_model                          ||    model(BaseModel)
+    #                     |                               ||
+    #              build_dataloader                       ||    dataloader
+    #                     |                               ||
+    #          model.setup_lr_schedulers                  ||    lr_scheduler
+    #                     |                               ||
+    #           model.setup_optimizers                    ||    optimizers
+    #                     |                               ||
+    #   train loop (model.setup_input + model.train_iter) ||    train loop
+    #                     |                               ||
+    #      print log (model.get_current_losses)           ||
+    #                     |                               ||
+    #      save checkpoint (model.nets)                   \/
+    """
+    def __init__(self, cfg):
         # build model
-        self.model = build_model(cfg)
+        self.model = build_model(cfg.model)
 
         # multiple gpus prepare
         if ParallelEnv().nranks > 1:
             self.distributed_data_parallel()
 
+        # build train dataloader
+        self.train_dataloader = build_dataloader(cfg.dataset.train)
+        self.iters_per_epoch = len(self.train_dataloader)
+
+        # build lr scheduler
+        # TODO: is there a better way?
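The IterLoader added in this hunk is what lets the trainer count iterations instead of epochs: it restarts the wrapped loader on StopIteration and tracks how many passes it has made. A minimal, self-contained sketch of that behavior (the class body is copied from the hunk above; a plain list of fake batches stands in for a real Paddle DataLoader):

class IterLoader:  # copied from the addition above
    def __init__(self, dataloader):
        self._dataloader = dataloader
        self.iter_loader = iter(self._dataloader)
        self._epoch = 1

    @property
    def epoch(self):
        return self._epoch

    def __next__(self):
        try:
            data = next(self.iter_loader)
        except StopIteration:
            self._epoch += 1
            self.iter_loader = iter(self._dataloader)
            data = next(self.iter_loader)
        return data

loader = IterLoader([{'A': 0}, {'A': 1}, {'A': 2}])
batches = [next(loader) for _ in range(7)]
# Seven draws from a three-batch loader wrap around twice: epoch 1 -> 2 -> 3.
assert loader.epoch == 3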
+ if 'lr_scheduler' in cfg and 'iters_per_epoch' in cfg.lr_scheduler: + cfg.lr_scheduler.iters_per_epoch = self.iters_per_epoch + self.lr_schedulers = self.model.setup_lr_schedulers(cfg.lr_scheduler) + + # build optimizers + self.optimizers = self.model.setup_optimizers(self.lr_schedulers, + cfg.optimizer) + + # build metrics + self.metrics = None + validate_cfg = cfg.get('validate', None) + if validate_cfg and 'metrics' in validate_cfg: + self.metrics = self.model.setup_metrics(validate_cfg['metrics']) + self.logger = logging.getLogger(__name__) self.enable_visualdl = cfg.get('enable_visualdl', False) if self.enable_visualdl: @@ -54,9 +107,18 @@ class Trainer: # base config self.output_dir = cfg.output_dir - self.epochs = cfg.epochs + self.epochs = cfg.get('epochs', None) + if self.epochs: + self.total_iters = self.epochs * self.iters_per_epoch + self.by_epoch = True + else: + self.by_epoch = False + self.total_iters = cfg.total_iters + self.start_epoch = 1 self.current_epoch = 1 + self.current_iter = 1 + self.inner_iter = 1 self.batch_id = 0 self.global_steps = 0 self.weight_interval = cfg.snapshot_config.interval @@ -69,10 +131,6 @@ class Trainer: self.local_rank = ParallelEnv().local_rank - # time count - self.steps_per_epoch = len(self.train_dataloader) - self.total_steps = self.epochs * self.steps_per_epoch - self.time_count = {} self.best_metric = {} @@ -85,117 +143,68 @@ class Trainer: reader_cost_averager = TimeAverager() batch_cost_averager = TimeAverager() - for epoch in range(self.start_epoch, self.epochs + 1): - self.current_epoch = epoch - start_time = step_start_time = time.time() - for i, data in enumerate(self.train_dataloader): - reader_cost_averager.record(time.time() - step_start_time) - - self.batch_id = i - # unpack data from dataset and apply preprocessing - # data input should be dict - self.model.set_input(data) - self.model.optimize_parameters() - - batch_cost_averager.record(time.time() - step_start_time, - num_samples=self.cfg.get( - 'batch_size', 1)) - if i % self.log_interval == 0: - self.data_time = reader_cost_averager.get_average() - self.step_time = batch_cost_averager.get_average() - self.ips = batch_cost_averager.get_ips_average() - self.print_log() - - reader_cost_averager.reset() - batch_cost_averager.reset() - - if i % self.visual_interval == 0: - self.visual('visual_train') - self.global_steps += 1 - step_start_time = time.time() - - self.logger.info( - 'train one epoch use time: {:.3f} seconds.'.format(time.time() - - start_time)) - if self.validate_interval > -1 and epoch % self.validate_interval: - self.validate() - self.model.lr_scheduler.step() - if epoch % self.weight_interval == 0: - self.save(epoch, 'weight', keep=-1) - self.save(epoch) - - def validate(self): - if not hasattr(self, 'val_dataloader'): - self.val_dataloader = build_dataloader(self.cfg.dataset.val, - is_train=False) + iter_loader = IterLoader(self.train_dataloader) - metric_result = {} + while self.current_iter < (self.total_iters + 1): + self.current_epoch = iter_loader.epoch + self.inner_iter = self.current_iter % self.iters_per_epoch - for i, data in enumerate(self.val_dataloader): - self.batch_id = i - - self.model.set_input(data) - self.model.test() - - visual_results = {} - current_paths = self.model.get_image_paths() - current_visuals = self.model.get_current_visuals() + start_time = step_start_time = time.time() + data = next(iter_loader) + reader_cost_averager.record(time.time() - step_start_time) + # unpack data from dataset and apply preprocessing + # data input 
should be dict + self.model.setup_input(data) + self.model.train_iter(self.optimizers) + + batch_cost_averager.record(time.time() - step_start_time, + num_samples=self.cfg.get( + 'batch_size', 1)) + if self.current_iter % self.log_interval == 0: + self.data_time = reader_cost_averager.get_average() + self.step_time = batch_cost_averager.get_average() + self.ips = batch_cost_averager.get_ips_average() + self.print_log() + + reader_cost_averager.reset() + batch_cost_averager.reset() + + if self.current_iter % self.visual_interval == 0: + self.visual('visual_train') + + step_start_time = time.time() - for j in range(len(current_paths)): - short_path = os.path.basename(current_paths[j]) - basename = os.path.splitext(short_path)[0] - for k, img_tensor in current_visuals.items(): - name = '%s_%s' % (basename, k) - visual_results.update({name: img_tensor[j]}) - if 'psnr' in self.cfg.validate.metrics: - if 'psnr' not in metric_result: - metric_result['psnr'] = calculate_psnr( - tensor2img(current_visuals['output'][j], (0., 1.)), - tensor2img(current_visuals['gt'][j], (0., 1.)), - **self.cfg.validate.metrics.psnr) - else: - metric_result['psnr'] += calculate_psnr( - tensor2img(current_visuals['output'][j], (0., 1.)), - tensor2img(current_visuals['gt'][j], (0., 1.)), - **self.cfg.validate.metrics.psnr) - if 'ssim' in self.cfg.validate.metrics: - if 'ssim' not in metric_result: - metric_result['ssim'] = calculate_ssim( - tensor2img(current_visuals['output'][j], (0., 1.)), - tensor2img(current_visuals['gt'][j], (0., 1.)), - **self.cfg.validate.metrics.ssim) - else: - metric_result['ssim'] += calculate_ssim( - tensor2img(current_visuals['output'][j], (0., 1.)), - tensor2img(current_visuals['gt'][j], (0., 1.)), - **self.cfg.validate.metrics.ssim) - - self.visual('visual_val', - visual_results=visual_results, - step=self.batch_id) + self.model.lr_scheduler.step() - if i % self.log_interval == 0: - self.logger.info('val iter: [%d/%d]' % - (i, len(self.val_dataloader))) + if self.by_epoch: + temp = self.current_epoch + else: + temp = self.current_iter + if self.validate_interval > -1 and temp % self.validate_interval == 0: + self.test() - for metric_name in metric_result.keys(): - metric_result[metric_name] /= len(self.val_dataloader.dataset) + if temp % self.weight_interval == 0: + self.save(temp, 'weight', keep=-1) + self.save(temp) - self.logger.info('Epoch {} validate end: {}'.format( - self.current_epoch, metric_result)) + self.current_iter += 1 def test(self): if not hasattr(self, 'test_dataloader'): self.test_dataloader = build_dataloader(self.cfg.dataset.test, - is_train=False) + is_train=False, + distributed=False) + + if self.metrics: + for metric in self.metrics.values(): + metric.reset() # data[0]: img, data[1]: img path index # test batch size must be 1 for i, data in enumerate(self.test_dataloader): - self.batch_id = i - self.model.set_input(data) - self.model.test() + self.model.setup_input(data) + self.model.test_iter(metrics=self.metrics) visual_results = {} current_paths = self.model.get_image_paths() @@ -217,11 +226,23 @@ class Trainer: self.logger.info('Test iter: [%d/%d]' % (i, len(self.test_dataloader))) + if self.metrics: + for metric_name, metric in self.metrics.items(): + self.logger.info("Metric {}: {:.4f}".format( + metric_name, metric.accumulate())) + def print_log(self): losses = self.model.get_current_losses() - message = 'Epoch: %d, iters: %d ' % (self.current_epoch, self.batch_id) - message += '%s: %.6f ' % ('lr', self.current_learning_rate) + message = '' + if 
self.by_epoch:
+            message += 'Epoch: %d/%d, iter: %d/%d ' % (
+                self.current_epoch, self.epochs, self.inner_iter,
+                self.iters_per_epoch)
+        else:
+            message += 'Iter: %d/%d ' % (self.current_iter, self.total_iters)
+
+        message += f'lr: {self.current_learning_rate:.3e} '
 
         for k, v in losses.items():
             message += '%s: %.3f ' % (k, v)
@@ -238,9 +259,7 @@ class Trainer:
             message += 'ips: %.5f images/s ' % self.ips
 
         if hasattr(self, 'step_time'):
-            cur_step = self.steps_per_epoch * (self.current_epoch -
-                                               1) + self.batch_id
-            eta = self.step_time * (self.total_steps - cur_step - 1)
+            eta = self.step_time * (self.total_iters - self.current_iter - 1)
             eta_str = str(datetime.timedelta(seconds=int(eta)))
             message += f'eta: {eta_str}'
@@ -274,6 +293,7 @@ class Trainer:
         min_max = self.cfg.get('min_max', None)
         if min_max is None:
             min_max = (-1., 1.)
+
         image_num = self.cfg.get('image_num', None)
         if (image_num is None) or (not self.enable_visualdl):
             image_num = 1
@@ -345,6 +365,13 @@ class Trainer:
         state_dicts = load(weight_path)
 
         for net_name, net in self.model.nets.items():
+            if net_name in state_dicts:
+                net.set_state_dict(state_dicts[net_name])
+                self.logger.info(
+                    'Loaded pretrained weight for net {}'.format(net_name))
+            else:
+                self.logger.warning(
+                    'Cannot find state dict of net {}. Skip loading pretrained weight for net {}'
+                    .format(net_name, net_name))
-            net.set_state_dict(state_dicts[net_name])
 
     def close(self):
diff --git a/ppgan/metric/README.md b/ppgan/metric/README.md
deleted file mode 100644
index 08fe7e700a48b8e8a9ad3053d03138498c1f5b61..0000000000000000000000000000000000000000
--- a/ppgan/metric/README.md
+++ /dev/null
@@ -1,19 +0,0 @@
-English (./README.md)
-
-# Usage
-
-To compute the FID score between two datasets, where images of each dataset are contained in an individual folder:
-
-wget https://paddlegan.bj.bcebos.com/InceptionV3.pdparams
-```
-python test_fid_score.py --image_data_path1 /path/to/dataset1 --image_data_path2 /path/to/dataset2 --inference_model ./InceptionV3.pdparams
-```
-
-### Inception-V3 weights converted from torchvision
-
-Download: https://aistudio.baidu.com/aistudio/datasetdetail/51890
-
-This model weights file is converted from official torchvision inception-v3 model. And both BigGAN and StarGAN-v2 is using it to calculate FID score.
-
-Note that this model weights is different from above one (which is converted from tensorflow unofficial version)
-
diff --git a/ppgan/metric/metric_util.py b/ppgan/metric/metric_util.py
deleted file mode 100644
index 857304d1be40a82dedec190eadea4778ef5b43c8..0000000000000000000000000000000000000000
--- a/ppgan/metric/metric_util.py
+++ /dev/null
@@ -1,94 +0,0 @@
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import numpy as np
-
-
-def reorder_image(img, input_order='HWC'):
-    """Reorder images to 'HWC' order.
-
-    If the input_order is (h, w), return (h, w, 1);
-    If the input_order is (c, h, w), return (h, w, c);
-    If the input_order is (h, w, c), return as it is.
- - Args: - img (ndarray): Input image. - input_order (str): Whether the input order is 'HWC' or 'CHW'. - If the input image shape is (h, w), input_order will not have - effects. Default: 'HWC'. - - Returns: - ndarray: reordered image. - """ - - if input_order not in ['HWC', 'CHW']: - raise ValueError( - f'Wrong input_order {input_order}. Supported input_orders are ' - "'HWC' and 'CHW'") - if len(img.shape) == 2: - img = img[..., None] - return img - if input_order == 'CHW': - img = img.transpose(1, 2, 0) - return img - - -def bgr2ycbcr(img, y_only=False): - """Convert a BGR image to YCbCr image. - - The bgr version of rgb2ycbcr. - It implements the ITU-R BT.601 conversion for standard-definition - television. See more details in - https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. - - It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`. - In OpenCV, it implements a JPEG conversion. See more details in - https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. - - Args: - img (ndarray): The input image. It accepts: - 1. np.uint8 type with range [0, 255]; - 2. np.float32 type with range [0, 1]. - y_only (bool): Whether to only return Y channel. Default: False. - - Returns: - ndarray: The converted YCbCr image. The output image has the same type - and range as input image. - """ - img_type = img.dtype - #img = _convert_input_type_range(img) - if y_only: - out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0 - else: - out_img = np.matmul( - img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], - [65.481, -37.797, 112.0]]) + [16, 128, 128] - #out_img = _convert_output_type_range(out_img, img_type) - return out_img - - -def to_y_channel(img): - """Change to Y channel of YCbCr. - - Args: - img (ndarray): Images with range [0, 255]. - - Returns: - (ndarray): Images with range [0, 255] (float type) without round. - """ - img = img.astype(np.float32) / 255. - if img.ndim == 3 and img.shape[2] == 3: - img = bgr2ycbcr(img, y_only=True) - img = img[..., None] - return img * 255. diff --git a/ppgan/metric/test_fid_score.py b/ppgan/metric/test_fid_score.py deleted file mode 100644 index 027a051012c11bf309605fccdd19f80b75b5fb01..0000000000000000000000000000000000000000 --- a/ppgan/metric/test_fid_score.py +++ /dev/null @@ -1,67 +0,0 @@ -#Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. -# -#Licensed under the Apache License, Version 2.0 (the "License"); -#you may not use this file except in compliance with the License. -#You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -#Unless required by applicable law or agreed to in writing, software -#distributed under the License is distributed on an "AS IS" BASIS, -#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -#See the License for the specific language governing permissions and -#limitations under the License. 
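The bgr2ycbcr helper deleted above (and re-added verbatim in psnr_ssim.py later in this diff) implements the BT.601 conversion its docstring cites. A self-contained numeric check of the Y-only path, using the same coefficients:

import numpy as np

def bt601_y(img_bgr01):
    # Y channel from float BGR in [0, 1], coefficients as in bgr2ycbcr above.
    return np.dot(img_bgr01, [24.966, 128.553, 65.481]) + 16.0

white = np.ones((1, 1, 3), np.float32)
black = np.zeros((1, 1, 3), np.float32)

assert np.isclose(bt601_y(white).item(), 235.0)  # BT.601 video-range white
assert np.isclose(bt601_y(black).item(), 16.0)   # BT.601 video-range black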
- -import argparse -from compute_fid import * - - -def parse_args(): - parser = argparse.ArgumentParser() - parser.add_argument('--image_data_path1', - type=str, - default='./real', - help='path of image data') - parser.add_argument('--image_data_path2', - type=str, - default='./fake', - help='path of image data') - parser.add_argument('--inference_model', - type=str, - default='./pretrained/params_inceptionV3', - help='path of inference_model.') - parser.add_argument('--use_gpu', - type=bool, - default=True, - help='default use gpu.') - parser.add_argument('--batch_size', - type=int, - default=1, - help='sample number in a batch for inference.') - parser.add_argument( - '--style', - type=str, - help='calculation style: stargan or default (gan-compression style)') - args = parser.parse_args() - return args - - -def main(): - args = parse_args() - path1 = args.image_data_path1 - path2 = args.image_data_path2 - paths = (path1, path2) - inference_model_path = args.inference_model - batch_size = args.batch_size - - fid_value = calculate_fid_given_paths(paths, - inference_model_path, - batch_size, - args.use_gpu, - 2048, - style=args.style) - print('FID: ', fid_value) - - -if __name__ == "__main__": - main() diff --git a/ppgan/metric/__init__.py b/ppgan/metrics/__init__.py similarity index 89% rename from ppgan/metric/__init__.py rename to ppgan/metrics/__init__.py index 9e83ec486fc6232d90d965425588645bc1204386..88afbd621fa8b9b7d4f6405e14aae406f5caa0f0 100644 --- a/ppgan/metric/__init__.py +++ b/ppgan/metrics/__init__.py @@ -11,3 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from .psnr_ssim import PSNR, SSIM +from .builder import build_metric diff --git a/ppgan/datasets/transforms/__init__.py b/ppgan/metrics/builder.py similarity index 72% rename from ppgan/datasets/transforms/__init__.py rename to ppgan/metrics/builder.py index 3590edea45b855649e54d367ee265e37c85b61b8..96822314c4ac2480490fc90c141fe179b791cc22 100644 --- a/ppgan/datasets/transforms/__init__.py +++ b/ppgan/metrics/builder.py @@ -12,4 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. 
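The new metrics package reuses the register/get pattern from ppgan.utils.registry. A simplified stand-in for that Registry class (the real implementation may differ in details) showing how the decorator and the name-popping build function seen below fit together:

class Registry:
    """Simplified stand-in for ppgan.utils.registry.Registry."""
    def __init__(self, name):
        self._name = name
        self._obj_map = {}

    def register(self):
        def deco(cls):
            self._obj_map[cls.__name__] = cls
            return cls
        return deco

    def get(self, name):
        return self._obj_map[name]

METRICS = Registry('METRIC')

@METRICS.register()
class PSNR:
    def __init__(self, crop_border, input_order='HWC', test_y_channel=False):
        self.crop_border = crop_border

# Mirrors build_metric below: pop 'name', pass the rest as kwargs.
cfg = {'name': 'PSNR', 'crop_border': 4, 'test_y_channel': True}
cfg_ = cfg.copy()
metric = METRICS.get(cfg_.pop('name'))(**cfg_)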
-from .transforms import PairedRandomCrop, PairedRandomHorizontalFlip, Add, ResizeToScale +import copy +import paddle + +from ..utils.registry import Registry + +METRICS = Registry("METRIC") + + +def build_metric(cfg): + cfg_ = cfg.copy() + name = cfg_.pop('name', None) + metric = METRICS.get(name)(**cfg_) + return metric diff --git a/ppgan/metric/compute_fid.py b/ppgan/metrics/compute_fid.py similarity index 100% rename from ppgan/metric/compute_fid.py rename to ppgan/metrics/compute_fid.py diff --git a/ppgan/metric/inception.py b/ppgan/metrics/inception.py similarity index 100% rename from ppgan/metric/inception.py rename to ppgan/metrics/inception.py diff --git a/ppgan/metric/psnr_ssim.py b/ppgan/metrics/psnr_ssim.py similarity index 56% rename from ppgan/metric/psnr_ssim.py rename to ppgan/metrics/psnr_ssim.py index d5c371cf67c107dc5385f919a4958a484b38ee40..fb362cb86eb8e50b974d90b14f029ad6961ead35 100644 --- a/ppgan/metric/psnr_ssim.py +++ b/ppgan/metrics/psnr_ssim.py @@ -14,8 +14,59 @@ import cv2 import numpy as np +import paddle -from .metric_util import reorder_image, to_y_channel +from .builder import METRICS + + +@METRICS.register() +class PSNR(paddle.metric.Metric): + def __init__(self, crop_border, input_order='HWC', test_y_channel=False): + self.crop_border = crop_border + self.input_order = input_order + self.test_y_channel = test_y_channel + self.reset() + + def reset(self): + self.results = [] + + def update(self, preds, gts): + if not isinstance(preds, (list, tuple)): + preds = [preds] + + if not isinstance(gts, (list, tuple)): + gts = [gts] + + for pred, gt in zip(preds, gts): + value = calculate_psnr(pred, gt, self.crop_border, self.input_order, + self.test_y_channel) + self.results.append(value) + + def accumulate(self): + if len(self.results) <= 0: + return 0. + return np.mean(self.results) + + def name(self): + return 'PSNR' + + +@METRICS.register() +class SSIM(PSNR): + def update(self, preds, gts): + if not isinstance(preds, (list, tuple)): + preds = [preds] + + if not isinstance(gts, (list, tuple)): + gts = [gts] + + for pred, gt in zip(preds, gts): + value = calculate_ssim(pred, gt, self.crop_border, self.input_order, + self.test_y_channel) + self.results.append(value) + + def name(self): + return 'SSIM' def calculate_psnr(img1, @@ -46,6 +97,8 @@ def calculate_psnr(img1, raise ValueError( f'Wrong input_order {input_order}. Supported input_orders are ' '"HWC" and "CHW"') + img1 = img1.copy().astype('float32') + img2 = img2.copy().astype('float32') img1 = reorder_image(img1, input_order=input_order) img2 = reorder_image(img2, input_order=input_order) @@ -134,6 +187,10 @@ def calculate_ssim(img1, raise ValueError( f'Wrong input_order {input_order}. Supported input_orders are ' '"HWC" and "CHW"') + + img1 = img1.copy().astype('float32')[..., ::-1] + img2 = img2.copy().astype('float32')[..., ::-1] + img1 = reorder_image(img1, input_order=input_order) img2 = reorder_image(img2, input_order=input_order) @@ -149,3 +206,81 @@ def calculate_ssim(img1, for i in range(img1.shape[2]): ssims.append(_ssim(img1[..., i], img2[..., i])) return np.array(ssims).mean() + + +def reorder_image(img, input_order='HWC'): + """Reorder images to 'HWC' order. + + If the input_order is (h, w), return (h, w, 1); + If the input_order is (c, h, w), return (h, w, c); + If the input_order is (h, w, c), return as it is. + + Args: + img (ndarray): Input image. + input_order (str): Whether the input order is 'HWC' or 'CHW'. + If the input image shape is (h, w), input_order will not have + effects. 
Default: 'HWC'. + + Returns: + ndarray: reordered image. + """ + + if input_order not in ['HWC', 'CHW']: + raise ValueError( + f'Wrong input_order {input_order}. Supported input_orders are ' + "'HWC' and 'CHW'") + if len(img.shape) == 2: + img = img[..., None] + return img + if input_order == 'CHW': + img = img.transpose(1, 2, 0) + return img + + +def bgr2ycbcr(img, y_only=False): + """Convert a BGR image to YCbCr image. + + The bgr version of rgb2ycbcr. + It implements the ITU-R BT.601 conversion for standard-definition + television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`. + In OpenCV, it implements a JPEG conversion. See more details in + https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + y_only (bool): Whether to only return Y channel. Default: False. + + Returns: + ndarray: The converted YCbCr image. The output image has the same type + and range as input image. + """ + img_type = img.dtype + + if y_only: + out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0 + else: + out_img = np.matmul( + img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], + [65.481, -37.797, 112.0]]) + [16, 128, 128] + return out_img + + +def to_y_channel(img): + """Change to Y channel of YCbCr. + + Args: + img (ndarray): Images with range [0, 255]. + + Returns: + (ndarray): Images with range [0, 255] (float type) without round. + """ + img = img.astype(np.float32) / 255. + if img.ndim == 3 and img.shape[2] == 3: + img = bgr2ycbcr(img, y_only=True) + img = img[..., None] + return img * 255. diff --git a/ppgan/models/__init__.py b/ppgan/models/__init__.py index 12a78ff7c0734e890e4c18f71075ae9c69f11e64..a675130094955e110121ed3a89e4097cefaed09d 100644 --- a/ppgan/models/__init__.py +++ b/ppgan/models/__init__.py @@ -16,9 +16,9 @@ from .base_model import BaseModel from .gan_model import GANModel from .cycle_gan_model import CycleGANModel from .pix2pix_model import Pix2PixModel -from .srgan_model import SRGANModel -from .sr_model import SRModel +from .sr_model import BaseSRModel from .makeup_model import MakeupModel +from .esrgan_model import ESRGAN from .ugatit_model import UGATITModel from .dc_gan_model import DCGANModel from .animeganv2_model import AnimeGANV2Model, AnimeGANV2PreTrainModel diff --git a/ppgan/models/animeganv2_model.py b/ppgan/models/animeganv2_model.py index 819a069d96c146a5fb9b11f9d0f004f65f24461e..9f768c7abf125350a678935e7c2f3fcd2ee71105 100644 --- a/ppgan/models/animeganv2_model.py +++ b/ppgan/models/animeganv2_model.py @@ -19,7 +19,7 @@ from .base_model import BaseModel from .builder import MODELS from .generators.builder import build_generator from .discriminators.builder import build_discriminator -from .losses import GANLoss +from .criterions.gan_loss import GANLoss from ..modules.caffevgg import CaffeVGG19 from ..solver import build_optimizer from ..modules.init import init_weights diff --git a/ppgan/models/base_model.py b/ppgan/models/base_model.py index 09b0c652a414b073dadd651d3bb1582c509fccae..6b92db7f3d0d0aab4078762e212f554713620a52 100644 --- a/ppgan/models/base_model.py +++ b/ppgan/models/base_model.py @@ -19,24 +19,40 @@ import numpy as np from collections import OrderedDict from abc import ABC, abstractmethod -from ..solver.lr_scheduler import build_lr_scheduler +from .criterions.builder import build_criterion +from 
..solver import build_lr_scheduler, build_optimizer +from ..metrics import build_metric +from ..utils.visual import tensor2img class BaseModel(ABC): """This class is an abstract base class (ABC) for models. To create a subclass, you need to implement the following five functions: - -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt). - -- : unpack data from dataset and apply preprocessing. - -- : produce intermediate results. - -- : calculate losses, gradients, and update network weights. - -- : (optionally) add model-specific options and set default options. + -- <__init__>: initialize the class. + -- : unpack data from dataset and apply preprocessing. + -- : produce intermediate results. + -- : calculate losses, gradients, and update network weights. + + # trainer training logic: + # + # build_model || model(BaseModel) + # | || + # build_dataloader || dataloader + # | || + # model.setup_lr_schedulers || lr_scheduler + # | || + # model.setup_optimizers || optimizers + # | || + # train loop (model.setup_input + model.train_iter) || train loop + # | || + # print log (model.get_current_losses) || + # | || + # save checkpoint (model.nets) \/ + """ - def __init__(self, cfg): + def __init__(self): """Initialize the BaseModel class. - Args: - cfg (Dict)-- configs of Model. - When creating your custom class, you need to implement your own initialization. In this function, you should first call Then, you need to define four lists: @@ -47,57 +63,85 @@ class BaseModel(ABC): If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example. """ - self.cfg = cfg - self.is_train = cfg.is_train - self.save_dir = os.path.join( - cfg.output_dir, - cfg.model.name) # save all the checkpoints to save_dir - self.losses = OrderedDict() self.nets = OrderedDict() - self.visual_items = OrderedDict() self.optimizers = OrderedDict() - self.image_paths = [] - self.metric = 0 # used for learning rate policy 'plateau' + self.metrics = OrderedDict() + self.losses = OrderedDict() + self.visual_items = OrderedDict() @abstractmethod - def set_input(self, input): + def setup_input(self, input): """Unpack input data from the dataloader and perform necessary pre-processing steps. - Parameters: + Args: input (dict): includes the data itself and its metadata information. 
""" pass @abstractmethod def forward(self): - """Run forward pass; called by both functions and .""" + """Run forward pass; called by both functions and .""" pass @abstractmethod - def optimize_parameters(self): + def train_iter(self, optims=None): """Calculate losses, gradients, and update network weights; called in every training iteration""" pass - def build_lr_scheduler(self): - self.lr_scheduler = build_lr_scheduler(self.cfg.lr_scheduler) + def test_iter(self, metrics=None): + """Calculate metrics; called in every test iteration""" + self.eval() + with paddle.no_grad(): + self.forward() + self.train() + + def setup_train_mode(self, is_train): + self.is_train = is_train + + def setup_lr_schedulers(self, cfg): + self.lr_scheduler = build_lr_scheduler(cfg) + return self.lr_scheduler + + def setup_optimizers(self, lr, cfg): + if cfg.get('name', None): + cfg_ = cfg.copy() + net_names = cfg_.pop('net_names') + parameters = [] + for net_name in net_names: + parameters += self.nets[net_name].parameters() + self.optimizers['optim'] = build_optimizer(cfg_, lr, parameters) + else: + for opt_name, opt_cfg in cfg.items(): + cfg_ = opt_cfg.copy() + net_names = cfg_.pop('net_names') + parameters = [] + for net_name in net_names: + parameters += self.nets[net_name].parameters() + self.optimizers[opt_name] = build_optimizer( + cfg_, lr, parameters) + + return self.optimizers + + def setup_metrics(self, cfg): + if isinstance(list(cfg.values())[0], dict): + for metric_name, cfg_ in cfg.items(): + self.metrics[metric_name] = build_metric(cfg_) + else: + metric = build_metric(cfg) + self.metrics[metric.__class__.__name__] = metric + + return self.metrics def eval(self): - """Make models eval mode during test time""" - for name in self.model_names: - if isinstance(name, str): - net = getattr(self, 'net' + name) - net.eval() - - def test(self): - """Forward function used in test time. + """Make nets eval mode during test time""" + for net in self.nets.values(): + net.eval() - This function wraps function in no_grad() so we don't save intermediate steps for backprop - It also calls to produce additional visualization results - """ - with paddle.no_grad(): - self.forward() - self.compute_visuals() + def train(self): + """Make nets train mode during train time""" + for net in self.nets.values(): + net.train() def compute_visuals(self): """Calculate additional output images for visdom and HTML visualization""" @@ -118,8 +162,8 @@ class BaseModel(ABC): def set_requires_grad(self, nets, requires_grad=False): """Set requies_grad=Fasle for all the networks to avoid unnecessary computations Args: - nets (network list) -- a list of networks - requires_grad (bool) -- whether the networks require gradients or not + nets (network list): a list of networks + requires_grad (bool): whether the networks require gradients or not """ if not isinstance(nets, list): nets = [nets] diff --git a/ppgan/models/builder.py b/ppgan/models/builder.py index 6ed338ea6ecb7178e5068df50380494f23de14b3..81b0c2a26a48e9d97a883a7463e1d0fff88c75f5 100644 --- a/ppgan/models/builder.py +++ b/ppgan/models/builder.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import copy
 import paddle
 
 from ..utils.registry import Registry
@@ -20,5 +21,7 @@ MODELS = Registry("MODEL")
 
 
 def build_model(cfg):
-    model = MODELS.get(cfg.model.name)(cfg)
+    cfg_ = cfg.copy()
+    name = cfg_.pop('name', None)
+    model = MODELS.get(name)(**cfg_)
     return model
diff --git a/ppgan/models/criterions/__init__.py b/ppgan/models/criterions/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c49542b201205ae4db830366d50e1553a5dc723
--- /dev/null
+++ b/ppgan/models/criterions/__init__.py
@@ -0,0 +1,5 @@
+from .gan_loss import GANLoss
+from .perceptual_loss import PerceptualLoss
+from .pixel_loss import L1Loss, MSELoss
+
+from .builder import build_criterion
diff --git a/ppgan/models/criterions/builder.py b/ppgan/models/criterions/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..31de2ff421f0dfc771801b84a94e70a4f699d218
--- /dev/null
+++ b/ppgan/models/criterions/builder.py
@@ -0,0 +1,28 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ...utils.registry import Registry
+
+CRITERIONS = Registry('CRITERION')
+
+
+def build_criterion(cfg):
+    cfg_ = cfg.copy()
+    name = cfg_.pop('name')
+    try:
+        criterion = CRITERIONS.get(name)(**cfg_)
+    except Exception as e:
+        cls_ = CRITERIONS.get(name)
+        raise RuntimeError('class {} {}'.format(cls_.__name__, e))
+    return criterion
diff --git a/ppgan/models/losses.py b/ppgan/models/criterions/gan_loss.py
similarity index 82%
rename from ppgan/models/losses.py
rename to ppgan/models/criterions/gan_loss.py
index 9139266894e44d3a284d6347d6043b33a1c43c1c..d3fbcda49c1e25b397522f8610eb11817d0e3f13 100644
--- a/ppgan/models/losses.py
+++ b/ppgan/models/criterions/gan_loss.py
@@ -16,29 +16,40 @@ import numpy as np
 
 import paddle
 import paddle.nn as nn
+from .builder import CRITERIONS
 import paddle.nn.functional as F
 
 
+@CRITERIONS.register()
 class GANLoss(nn.Layer):
     """Define different GAN objectives.
 
     The GANLoss class abstracts away the need to create the target label tensor
     that has the same size as the input.
     """
-    def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
+    def __init__(self,
+                 gan_mode,
+                 target_real_label=1.0,
+                 target_fake_label=0.0,
+                 loss_weight=1.0):
         """ Initialize the GANLoss class.
 
-        Parameters:
-            gan_mode (str) - - the type of GAN objective.
It currently supports vanilla, lsgan, and wgangp.
-            target_real_label (bool) - - label for a real image
-            target_fake_label (bool) - - label of a fake image
+        Args:
+            gan_mode (str): the type of GAN objective. It currently supports vanilla, lsgan, and wgangp.
+            target_real_label (bool): label for a real image
+            target_fake_label (bool): label of a fake image
 
         Note: Do not use sigmoid as the last layer of Discriminator.
         LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss.
         """
         super(GANLoss, self).__init__()
+        # when loss weight is less than or equal to zero, return None
+        if loss_weight <= 0:
+            return None
+
         self.target_real_label = target_real_label
         self.target_fake_label = target_fake_label
+        self.loss_weight = loss_weight
 
         self.gan_mode = gan_mode
         if gan_mode == 'lsgan':
@@ -53,7 +64,7 @@ class GANLoss(nn.Layer):
     def get_target_tensor(self, prediction, target_is_real):
         """Create label tensors with the same size as the input.
 
-        Parameters:
+        Args:
             prediction (tensor) - - tpyically the prediction from a discriminator
             target_is_real (bool) - - if the ground truth label is for real images or fake images
 
@@ -75,16 +86,20 @@ class GANLoss(nn.Layer):
                                                  dtype='float32')
             target_tensor = self.target_fake_tensor
 
-        # target_tensor.stop_gradient = True
         return target_tensor
 
-    def __call__(self, prediction, target_is_real, is_updating_D=None):
+    def __call__(self,
+                 prediction,
+                 target_is_real,
+                 is_disc=False,
+                 is_updating_D=None):
         """Calculate loss given Discriminator's output and grount truth labels.
 
-        Parameters:
+        Args:
             prediction (tensor) - - tpyically the prediction output from a discriminator
             target_is_real (bool) - - if the ground truth label is for real images or fake images
-            is_updating_D (bool)  - - if we are in updating D step or not
+            is_disc (bool) - - if True, return the raw loss without multiplying by loss_weight
+            is_updating_D (bool) - - if we are in updating D step or not
 
         Returns:
             the calculated loss.
@@ -108,4 +122,5 @@ class GANLoss(nn.Layer):
             loss = F.softplus(-prediction).mean()
         else:
             loss = F.softplus(prediction).mean()
-        return loss
+
+        return loss if is_disc else loss * self.loss_weight
diff --git a/ppgan/models/criterions/perceptual_loss.py b/ppgan/models/criterions/perceptual_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..f112b9f5c14eeaf0e5ae726aa91beb4c2d8bd5b1
--- /dev/null
+++ b/ppgan/models/criterions/perceptual_loss.py
@@ -0,0 +1,200 @@
+import paddle
+import paddle.nn as nn
+import paddle.vision.models.vgg as vgg
+
+from ppgan.utils.download import get_path_from_url
+from .builder import CRITERIONS
+
+
+class PerceptualVGG(nn.Layer):
+    """VGG network used in calculating perceptual loss.
+
+    In this implementation, we allow users to choose whether to use
+    normalization in the input feature and the type of vgg network. Note
+    that the pretrained path must fit the vgg type.
+
+    Args:
+        layer_name_list (list[str]): According to the name in this list,
+            forward function will return the corresponding features. This
+            list contains the name of each layer in `vgg.feature`. An example
+            of this list is ['4', '10'].
+        vgg_type (str): Set the type of vgg network. Default: 'vgg19'.
+        use_input_norm (bool): If True, normalize the input image.
+            Importantly, the input feature must be in the range [0, 1].
+            Default: True.
+        pretrained_url (str): Path for pretrained weights.
Default:
+    """
+    def __init__(
+        self,
+        layer_name_list,
+        vgg_type='vgg19',
+        use_input_norm=True,
+        pretrained_url='https://paddlegan.bj.bcebos.com/model/vgg19.pdparams'
+    ):
+        super(PerceptualVGG, self).__init__()
+
+        self.layer_name_list = layer_name_list
+        self.use_input_norm = use_input_norm
+
+        # get vgg model and load pretrained vgg weight
+        _vgg = getattr(vgg, vgg_type)()
+
+        if pretrained_url:
+            weight_path = get_path_from_url(pretrained_url)
+            state_dict = paddle.load(weight_path)
+            _vgg.load_dict(state_dict)
+            print('PerceptualVGG loaded pretrained weight.')
+
+        num_layers = max(map(int, layer_name_list)) + 1
+        assert len(_vgg.features) >= num_layers
+
+        # only borrow layers that will be used from _vgg to avoid unused params
+        self.vgg_layers = nn.Sequential(
+            *list(_vgg.features.children())[:num_layers])
+
+        if self.use_input_norm:
+            # the mean is for image with range [0, 1]
+            self.register_buffer(
+                'mean',
+                paddle.to_tensor([0.485, 0.456, 0.406]).reshape([1, 3, 1, 1]))
+            # the std is also for image with range [0, 1]
+            self.register_buffer(
+                'std',
+                paddle.to_tensor([0.229, 0.224, 0.225]).reshape([1, 3, 1, 1]))
+
+        for v in self.vgg_layers.parameters():
+            v.trainable = False
+
+    def forward(self, x):
+        """Forward function.
+
+        Args:
+            x (Tensor): Input tensor with shape (n, c, h, w).
+
+        Returns:
+            Tensor: Forward results.
+        """
+
+        if self.use_input_norm:
+            x = (x - self.mean) / self.std
+        output = {}
+
+        for name, module in self.vgg_layers.named_children():
+            x = module(x)
+            if name in self.layer_name_list:
+                output[name] = x.clone()
+        return output
+
+
+@CRITERIONS.register()
+class PerceptualLoss(nn.Layer):
+    """Perceptual loss with commonly used style loss.
+
+    Args:
+        layer_weights (dict): The weight for each layer of vgg feature.
+            Here is an example: {'4': 1., '9': 1., '18': 1.}, which means the
+            5th, 10th and 19th feature layer will be extracted with weight 1.0
+            in calculating losses.
+        vgg_type (str): The type of vgg network used as feature extractor.
+            Default: 'vgg19'.
+        use_input_norm (bool): If True, normalize the input image in vgg.
+            Default: True.
+        perceptual_weight (float): If `perceptual_weight > 0`, the perceptual
+            loss will be calculated and the loss will be multiplied by the
+            weight. Default: 1.0.
+        style_weight (float): If `style_weight > 0`, the style loss will be
+            calculated and the loss will be multiplied by the weight.
+            Default: 1.0.
+        norm_img (bool): If True, the image will be normed to [0, 1]. Note that
+            this is different from the `use_input_norm` which norms the input
+            in the forward function of vgg according to the statistics of the
+            dataset. Importantly, the input image must be in range [-1, 1].
+        pretrained (str): Path for pretrained weights.
Default:
+
+    """
+    def __init__(
+            self,
+            layer_weights,
+            vgg_type='vgg19',
+            use_input_norm=True,
+            perceptual_weight=1.0,
+            style_weight=1.0,
+            norm_img=True,
+            pretrained='https://paddlegan.bj.bcebos.com/model/vgg19.pdparams',
+            criterion='l1'):
+        super(PerceptualLoss, self).__init__()
+        # when loss weight is less than or equal to zero, return None
+        if perceptual_weight <= 0 and style_weight <= 0:
+            return None
+
+        self.norm_img = norm_img
+        self.perceptual_weight = perceptual_weight
+        self.style_weight = style_weight
+        self.layer_weights = layer_weights
+        self.vgg = PerceptualVGG(layer_name_list=list(layer_weights.keys()),
+                                 vgg_type=vgg_type,
+                                 use_input_norm=use_input_norm,
+                                 pretrained_url=pretrained)
+
+        if criterion == 'l1':
+            self.criterion = nn.L1Loss()
+        else:
+            raise NotImplementedError(
+                f'{criterion} criterion has not been supported in'
+                ' this version.')
+
+    def forward(self, x, gt):
+        """Forward function.
+
+        Args:
+            x (Tensor): Input tensor with shape (n, c, h, w).
+            gt (Tensor): Ground-truth tensor with shape (n, c, h, w).
+
+        Returns:
+            Tensor: Forward results.
+        """
+
+        if self.norm_img:
+            x = (x + 1.) * 0.5
+            gt = (gt + 1.) * 0.5
+        # extract vgg features
+        x_features = self.vgg(x)
+        gt_features = self.vgg(gt.detach())
+
+        # calculate perceptual loss
+        if self.perceptual_weight > 0:
+            percep_loss = 0
+            for k in x_features.keys():
+                percep_loss += self.criterion(
+                    x_features[k], gt_features[k]) * self.layer_weights[k]
+            percep_loss *= self.perceptual_weight
+        else:
+            percep_loss = None
+
+        # calculate style loss
+        if self.style_weight > 0:
+            style_loss = 0
+            for k in x_features.keys():
+                style_loss += self.criterion(self._gram_mat(
+                    x_features[k]), self._gram_mat(
+                        gt_features[k])) * self.layer_weights[k]
+            style_loss *= self.style_weight
+        else:
+            style_loss = None
+
+        return percep_loss, style_loss
+
+    def _gram_mat(self, x):
+        """Calculate Gram matrix.
+
+        Args:
+            x (paddle.Tensor): Tensor with shape of (n, c, h, w).
+
+        Returns:
+            paddle.Tensor: Gram matrix.
+        """
+        (n, c, h, w) = x.shape
+        features = x.reshape([n, c, w * h])
+        features_t = features.transpose([0, 2, 1])
+        gram = features.bmm(features_t) / (c * h * w)
+        return gram
diff --git a/ppgan/models/criterions/pixel_loss.py b/ppgan/models/criterions/pixel_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c949766d8d023e001c8a7c1ebabb6b56897f488
--- /dev/null
+++ b/ppgan/models/criterions/pixel_loss.py
@@ -0,0 +1,108 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+
+import paddle
+import paddle.nn as nn
+from .builder import CRITERIONS
+
+
+@CRITERIONS.register()
+class L1Loss():
+    """L1 (mean absolute error, MAE) loss.
+
+    Args:
+        reduction (str): Specifies the reduction to apply to the output.
+            Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
+        loss_weight (float): Loss weight for L1 loss. Default: 1.0.
+
+    """
+    def __init__(self, reduction='mean', loss_weight=1.0):
+        # when loss weight is less than or equal to zero, return None
+        if loss_weight <= 0:
+            return None
+        self._l1_loss = nn.L1Loss(reduction)
+        self.loss_weight = loss_weight
+        self.reduction = reduction
+
+    def __call__(self, pred, target, **kwargs):
+        """Forward Function.
+
+        Args:
+            pred (Tensor): of shape (N, C, H, W). Predicted tensor.
+            target (Tensor): of shape (N, C, H, W). Ground truth tensor.
+            weight (Tensor, optional): of shape (N, C, H, W). Element-wise
+                weights. Default: None.
+        """
+        return self.loss_weight * self._l1_loss(pred, target)
+
+
+@CRITERIONS.register()
+class MSELoss():
+    """MSE (L2) loss.
+
+    Args:
+        reduction (str): Specifies the reduction to apply to the output.
+            Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
+        loss_weight (float): Loss weight for MSE loss. Default: 1.0.
+
+    """
+    def __init__(self, reduction='mean', loss_weight=1.0):
+        # when loss weight is less than or equal to zero, return None
+        if loss_weight <= 0:
+            return None
+        self._l2_loss = nn.MSELoss(reduction)
+        self.loss_weight = loss_weight
+        self.reduction = reduction
+
+    def __call__(self, pred, target, **kwargs):
+        """Forward Function.
+
+        Args:
+            pred (Tensor): of shape (N, C, H, W). Predicted tensor.
+            target (Tensor): of shape (N, C, H, W). Ground truth tensor.
+            weight (Tensor, optional): of shape (N, C, H, W). Element-wise
+                weights. Default: None.
+        """
+        return self.loss_weight * self._l2_loss(pred, target)
+
+
+@CRITERIONS.register()
+class BCEWithLogitsLoss():
+    """BCE loss.
+
+    Args:
+        reduction (str): Specifies the reduction to apply to the output.
+            Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
+        loss_weight (float): Loss weight for BCE loss. Default: 1.0.
+    """
+    def __init__(self, reduction='mean', loss_weight=1.0):
+        # when loss weight is less than or equal to zero, return None
+        if loss_weight <= 0:
+            return None
+        self._bce_loss = nn.BCEWithLogitsLoss(reduction=reduction)
+        self.loss_weight = loss_weight
+        self.reduction = reduction
+
+    def __call__(self, pred, target, **kwargs):
+        """Forward Function.
+
+        Args:
+            pred (Tensor): of shape (N, C, H, W). Predicted tensor.
+            target (Tensor): of shape (N, C, H, W). Ground truth tensor.
+            weight (Tensor, optional): of shape (N, C, H, W). Element-wise
+                weights. Default: None.
+        """
+        return self.loss_weight * self._bce_loss(pred, target)
diff --git a/ppgan/models/cycle_gan_model.py b/ppgan/models/cycle_gan_model.py
index c981dbf6c92a0603059804904f12cdcf54619ef5..6d1c3f0950b96dc9588ab598a572cc9497c302d3 100644
--- a/ppgan/models/cycle_gan_model.py
+++ b/ppgan/models/cycle_gan_model.py
@@ -18,9 +18,8 @@ from .base_model import BaseModel
 from .builder import MODELS
 from .generators.builder import build_generator
 from .discriminators.builder import build_discriminator
-from .losses import GANLoss
+from .criterions import build_criterion
 
-from ..solver import build_optimizer
 from ..modules.init import init_weights
 from ..utils.image_pool import ImagePool
 
@@ -30,61 +29,62 @@ class CycleGANModel(BaseModel):
     """
     This class implements the CycleGAN model, for learning image-to-image translation without paired data.
 
-    The model training requires '--dataset_mode unaligned' dataset.
-    By default, it uses a '--netG resnet_9blocks' ResNet generator,
-    a '--netD basic' discriminator (PatchGAN introduced by pix2pix),
-    and a least-square GANs objective ('--gan_mode lsgan').
- CycleGAN paper: https://arxiv.org/pdf/1703.10593.pdf """ - def __init__(self, cfg): + def __init__(self, + generator, + discriminator=None, + cycle_criterion=None, + idt_criterion=None, + gan_criterion=None, + pool_size=50, + direction='a2b', + lambda_a=10., + lambda_b=10.): """Initialize the CycleGAN class. - Parameters: - opt (config)-- stores all the experiment flags; needs to be a subclass of Dict + Args: + generator (dict): config of generator. + discriminator (dict): config of discriminator. + cycle_criterion (dict): config of cycle criterion. """ - super(CycleGANModel, self).__init__(cfg) + super(CycleGANModel, self).__init__() + + self.direction = direction - # define networks (both Generators and discriminators) + self.lambda_a = lambda_a + self.lambda_b = lambda_b + # define generators # The naming is different from those used in the paper. # Code (vs. paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X) - self.nets['netG_A'] = build_generator(cfg.model.generator) - self.nets['netG_B'] = build_generator(cfg.model.generator) + self.nets['netG_A'] = build_generator(generator) + self.nets['netG_B'] = build_generator(generator) init_weights(self.nets['netG_A']) init_weights(self.nets['netG_B']) - if self.is_train: # define discriminators - self.nets['netD_A'] = build_discriminator(cfg.model.discriminator) - self.nets['netD_B'] = build_discriminator(cfg.model.discriminator) + # define discriminators + if discriminator: + self.nets['netD_A'] = build_discriminator(discriminator) + self.nets['netD_B'] = build_discriminator(discriminator) init_weights(self.nets['netD_A']) init_weights(self.nets['netD_B']) - if self.is_train: - if cfg.lambda_identity > 0.0: # only works when input and output images have the same number of channels - assert ( - cfg.dataset.train.input_nc == cfg.dataset.train.output_nc) - # create image buffer to store previously generated images - self.fake_A_pool = ImagePool(cfg.dataset.train.pool_size) - # create image buffer to store previously generated images - self.fake_B_pool = ImagePool(cfg.dataset.train.pool_size) - # define loss functions - self.criterionGAN = GANLoss(cfg.model.gan_mode) - self.criterionCycle = paddle.nn.L1Loss() - self.criterionIdt = paddle.nn.L1Loss() - - self.build_lr_scheduler() - self.optimizers['optimizer_G'] = build_optimizer( - cfg.optimizer, - self.lr_scheduler, - parameter_list=self.nets['netG_A'].parameters() + - self.nets['netG_B'].parameters()) - self.optimizers['optimizer_D'] = build_optimizer( - cfg.optimizer, - self.lr_scheduler, - parameter_list=self.nets['netD_A'].parameters() + - self.nets['netD_B'].parameters()) - - def set_input(self, input): + # create image buffer to store previously generated images + self.fake_A_pool = ImagePool(pool_size) + # create image buffer to store previously generated images + self.fake_B_pool = ImagePool(pool_size) + + # define loss functions + if gan_criterion: + self.gan_criterion = build_criterion(gan_criterion) + + if cycle_criterion: + self.cycle_criterion = build_criterion(cycle_criterion) + + if idt_criterion: + self.idt_criterion = build_criterion(idt_criterion) + + def setup_input(self, input): """Unpack input data from the dataloader and perform necessary pre-processing steps. Args: @@ -92,8 +92,8 @@ class CycleGANModel(BaseModel): The option 'direction' can be used to swap domain A and domain B. 
""" - mode = 'train' if self.is_train else 'test' - AtoB = self.cfg.dataset[mode].direction == 'AtoB' + + AtoB = self.direction == 'a2b' if AtoB: if 'A' in input: @@ -134,20 +134,22 @@ class CycleGANModel(BaseModel): def backward_D_basic(self, netD, real, fake): """Calculate GAN loss for the discriminator - Parameters: - netD (network) -- the discriminator D - real (tensor array) -- real images - fake (tensor array) -- images generated by a generator + Args: + netD (Layer): the discriminator D + real (paddle.Tensor): real images + fake (paddle.Tensor): images generated by a generator + + Return: + the discriminator loss. - Return the discriminator loss. We also call loss_D.backward() to calculate the gradients. """ # Real pred_real = netD(real) - loss_D_real = self.criterionGAN(pred_real, True) + loss_D_real = self.gan_criterion(pred_real, True) # Fake pred_fake = netD(fake.detach()) - loss_D_fake = self.criterionGAN(pred_fake, False) + loss_D_fake = self.gan_criterion(pred_fake, False) # Combined loss and calculate gradients loss_D = (loss_D_real + loss_D_fake) * 0.5 @@ -170,16 +172,13 @@ class CycleGANModel(BaseModel): def backward_G(self): """Calculate the loss for generators G_A and G_B""" - lambda_idt = self.cfg.lambda_identity - lambda_A = self.cfg.lambda_A - lambda_B = self.cfg.lambda_B # Identity loss - if lambda_idt > 0: + if self.idt_criterion: # G_A should be identity if real_B is fed: ||G_A(B) - B|| self.idt_A = self.nets['netG_A'](self.real_B) - self.loss_idt_A = self.criterionIdt( - self.idt_A, self.real_B) * lambda_B * lambda_idt + self.loss_idt_A = self.idt_criterion(self.idt_A, + self.real_B) * self.lambda_b # G_B should be identity if real_A is fed: ||G_B(A) - A|| self.idt_B = self.nets['netG_B'](self.real_A) @@ -187,24 +186,24 @@ class CycleGANModel(BaseModel): self.visual_items['idt_A'] = self.idt_A self.visual_items['idt_B'] = self.idt_B - self.loss_idt_B = self.criterionIdt( - self.idt_B, self.real_A) * lambda_A * lambda_idt + self.loss_idt_B = self.idt_criterion(self.idt_B, + self.real_A) * self.lambda_a else: self.loss_idt_A = 0 self.loss_idt_B = 0 # GAN loss D_A(G_A(A)) - self.loss_G_A = self.criterionGAN(self.nets['netD_A'](self.fake_B), - True) + self.loss_G_A = self.gan_criterion(self.nets['netD_A'](self.fake_B), + True) # GAN loss D_B(G_B(B)) - self.loss_G_B = self.criterionGAN(self.nets['netD_B'](self.fake_A), - True) + self.loss_G_B = self.gan_criterion(self.nets['netD_B'](self.fake_A), + True) # Forward cycle loss || G_B(G_A(A)) - A|| - self.loss_cycle_A = self.criterionCycle(self.rec_A, - self.real_A) * lambda_A + self.loss_cycle_A = self.cycle_criterion(self.rec_A, + self.real_A) * self.lambda_a # Backward cycle loss || G_A(G_B(B)) - B|| - self.loss_cycle_B = self.criterionCycle(self.rec_B, - self.real_B) * lambda_B + self.loss_cycle_B = self.cycle_criterion(self.rec_B, + self.real_B) * self.lambda_b self.losses['G_idt_A_loss'] = self.loss_idt_A self.losses['G_idt_B_loss'] = self.loss_idt_B @@ -217,7 +216,7 @@ class CycleGANModel(BaseModel): self.loss_G.backward() - def optimize_parameters(self): + def train_iter(self, optimizers=None): """Calculate losses, gradients, and update network weights; called in every training iteration""" # forward # compute fake images and reconstruction images. 
@@ -227,19 +226,19 @@ class CycleGANModel(BaseModel):
         self.set_requires_grad([self.nets['netD_A'], self.nets['netD_B']],
                                False)
         # set G_A and G_B's gradients to zero
-        self.optimizers['optimizer_G'].clear_grad()
+        optimizers['optimG'].clear_grad()
         # calculate gradients for G_A and G_B
         self.backward_G()
         # update G_A and G_B's weights
-        self.optimizers['optimizer_G'].step()
+        optimizers['optimG'].step()
         # D_A and D_B
         self.set_requires_grad([self.nets['netD_A'], self.nets['netD_B']],
                                True)
         # set D_A and D_B's gradients to zero
-        self.optimizers['optimizer_D'].clear_grad()
+        optimizers['optimD'].clear_grad()
         # calculate gradients for D_A
         self.backward_D_A()
         # calculate gradients for D_B
         self.backward_D_B()
         # update D_A and D_B's weights
-        self.optimizers['optimizer_D'].step()
+        optimizers['optimD'].step()
diff --git a/ppgan/models/dc_gan_model.py b/ppgan/models/dc_gan_model.py
index e279527357056ba65d915e8ac73727923014195d..b13e494af2d83d34aecc878ccaa8e505d7327796 100644
--- a/ppgan/models/dc_gan_model.py
+++ b/ppgan/models/dc_gan_model.py
@@ -18,70 +18,54 @@ from .base_model import BaseModel
 from .builder import MODELS
 from .generators.builder import build_generator
 from .discriminators.builder import build_discriminator
-from .losses import GANLoss
-
-from ..solver import build_optimizer
+from .criterions import build_criterion
 
 from ..modules.init import init_weights
 
 
 @MODELS.register()
 class DCGANModel(BaseModel):
-    """ This class implements the DCGAN model, for learning a distribution from input images.
-
-    The model training requires dataset.
-    By default, it uses a '--netG DCGenerator' generator,
-    a '--netD DCDiscriminator' discriminator,
-    and a vanilla GAN loss (the cross-entropy objective used in the orignal GAN paper).
-
+    """
+    This class implements the DCGAN model, for learning a distribution from input images.
     DCGAN paper: https://arxiv.org/pdf/1511.06434
     """
-    def __init__(self, cfg):
+    def __init__(self, generator, discriminator=None, gan_criterion=None):
         """Initialize the DCGAN class.
-
-        Parameters:
-            opt (config dict)-- stores all the experiment flags; needs to be a subclass of Dict
+        Args:
+            generator (dict): config of generator.
+            discriminator (dict): config of discriminator.
+            gan_criterion (dict): config of gan criterion.
         """
-        super(DCGANModel, self).__init__(cfg)
+        super(DCGANModel, self).__init__()
+        self.gen_cfg = generator
 
         # define networks (both generator and discriminator)
-        self.nets['netG'] = build_generator(cfg.model.generator)
+        self.nets['netG'] = build_generator(generator)
         init_weights(self.nets['netG'])
-        self.cfg = cfg
+
         if self.is_train:
-            self.nets['netD'] = build_discriminator(cfg.model.discriminator)
+            self.nets['netD'] = build_discriminator(discriminator)
             init_weights(self.nets['netD'])
 
-        if self.is_train:
-            self.losses = {}
-            # define loss functions
-            self.criterionGAN = GANLoss(cfg.model.gan_mode)
-
-            # build optimizers
-            self.build_lr_scheduler()
-            self.optimizers['optimizer_G'] = build_optimizer(
-                cfg.optimizer,
-                self.lr_scheduler,
-                parameter_list=self.nets['netG'].parameters())
-            self.optimizers['optimizer_D'] = build_optimizer(
-                cfg.optimizer,
-                self.lr_scheduler,
-                parameter_list=self.nets['netD'].parameters())
-
-    def set_input(self, input):
+        if gan_criterion:
+            self.gan_criterion = build_criterion(gan_criterion)
+
+    def setup_input(self, input):
         """Unpack input data from the dataloader and perform necessary pre-processing steps.
-        Parameters:
+        Args:
             input (dict): include the data itself and its metadata information.
         """
         # get 1-channel gray image, or 3-channel color image
-        self.real = paddle.to_tensor(input['A'][:,0:self.cfg.model.generator.input_nc,:,:])
-        self.image_paths = input['A_paths']
+        self.real = paddle.to_tensor(input['A'])
+        self.image_paths = input['A_path']
 
     def forward(self):
-        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
+        """Run forward pass; called by both functions <train_iter> and <test_iter>."""
         # generate random noise and fake image
-        self.z = paddle.rand(shape=(self.real.shape[0],self.cfg.model.generator.input_nz,1,1))
-        self.fake = self.nets['netG'](self.z)
+        self.z = paddle.rand(shape=(self.real.shape[0], self.gen_cfg.input_nz,
+                                    1, 1))
+        self.fake = self.nets['netG'](self.z)
 
         # put items to visual dict
         self.visual_items['real'] = self.real
@@ -91,10 +75,10 @@ class DCGANModel(BaseModel):
         """Calculate GAN loss for the discriminator"""
         # Fake; stop backprop to the generator by detaching fake
         pred_fake = self.nets['netD'](self.fake.detach())
-        self.loss_D_fake = self.criterionGAN(pred_fake, False)
+        self.loss_D_fake = self.gan_criterion(pred_fake, False)
 
         pred_real = self.nets['netD'](self.real)
-        self.loss_D_real = self.criterionGAN(pred_real, True)
+        self.loss_D_real = self.gan_criterion(pred_real, True)
 
         # combine loss and calculate gradients
         self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
@@ -108,7 +92,7 @@ class DCGANModel(BaseModel):
         """Calculate GAN loss for the generator"""
         # G(A) should fake the discriminator
         pred_fake = self.nets['netD'](self.fake)
-        self.loss_G_GAN = self.criterionGAN(pred_fake, True)
+        self.loss_G_GAN = self.gan_criterion(pred_fake, True)
 
         # combine loss and calculate gradients
         self.loss_G = self.loss_G_GAN
@@ -117,7 +101,7 @@ class DCGANModel(BaseModel):
 
         self.losses['G_adv_loss'] = self.loss_G_GAN
 
-    def optimize_parameters(self):
+    def train_iter(self, optimizers=None):
         # compute fake images: G(A)
         self.forward()
 
@@ -133,4 +117,4 @@ class DCGANModel(BaseModel):
         self.set_requires_grad(self.nets['netG'], True)
         self.optimizers['optimizer_G'].clear_grad()
         self.backward_G()
-        self.optimizers['optimizer_G'].step()
\ No newline at end of file
+        self.optimizers['optimizer_G'].step()
diff --git a/ppgan/models/discriminators/__init__.py b/ppgan/models/discriminators/__init__.py
index a2c991511c79e76e71535b8b8dd7c9b244d198da..41c23b5210ab737d6b31b0db2daec6d1636792b9 100644
--- a/ppgan/models/discriminators/__init__.py
+++ b/ppgan/models/discriminators/__init__.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from .vgg_discriminator import VGGDiscriminator128
 from .nlayers import NLayerDiscriminator, NLayerDiscriminatorWithClassification
 from .discriminator_ugatit import UGATITDiscriminator
 from .dcdiscriminator import DCDiscriminator
diff --git a/ppgan/models/discriminators/vgg_discriminator.py b/ppgan/models/discriminators/vgg_discriminator.py
new file mode 100644
index 0000000000000000000000000000000000000000..454e964446e67ee5de4fd388de3a0cb9fd6155ad
--- /dev/null
+++ b/ppgan/models/discriminators/vgg_discriminator.py
@@ -0,0 +1,118 @@
+import paddle.nn as nn
+
+from .builder import DISCRIMINATORS
+
+
+@DISCRIMINATORS.register()
+class VGGDiscriminator128(nn.Layer):
+    """VGG style discriminator with input size 128 x 128.
+
+    It is used to train SRGAN and ESRGAN.
+
+    Args:
+        in_channels (int): Channel number of inputs. Default: 3.
+        num_feat (int): Channel number of base intermediate features.
+ Default: 64. + """ + def __init__(self, in_channels, num_feat, norm_layer='batch'): + super(VGGDiscriminator128, self).__init__() + + self.conv0_0 = nn.Conv2D(in_channels, num_feat, 3, 1, 1, bias_attr=True) + self.conv0_1 = nn.Conv2D(num_feat, num_feat, 4, 2, 1, bias_attr=False) + self.bn0_1 = nn.BatchNorm2D(num_feat) + + self.conv1_0 = nn.Conv2D(num_feat, + num_feat * 2, + 3, + 1, + 1, + bias_attr=False) + self.bn1_0 = nn.BatchNorm2D(num_feat * 2) + self.conv1_1 = nn.Conv2D(num_feat * 2, + num_feat * 2, + 4, + 2, + 1, + bias_attr=False) + self.bn1_1 = nn.BatchNorm2D(num_feat * 2) + + self.conv2_0 = nn.Conv2D(num_feat * 2, + num_feat * 4, + 3, + 1, + 1, + bias_attr=False) + self.bn2_0 = nn.BatchNorm2D(num_feat * 4) + self.conv2_1 = nn.Conv2D(num_feat * 4, + num_feat * 4, + 4, + 2, + 1, + bias_attr=False) + self.bn2_1 = nn.BatchNorm2D(num_feat * 4) + + self.conv3_0 = nn.Conv2D(num_feat * 4, + num_feat * 8, + 3, + 1, + 1, + bias_attr=False) + self.bn3_0 = nn.BatchNorm2D(num_feat * 8) + self.conv3_1 = nn.Conv2D(num_feat * 8, + num_feat * 8, + 4, + 2, + 1, + bias_attr=False) + self.bn3_1 = nn.BatchNorm2D(num_feat * 8) + + self.conv4_0 = nn.Conv2D(num_feat * 8, + num_feat * 8, + 3, + 1, + 1, + bias_attr=False) + self.bn4_0 = nn.BatchNorm2D(num_feat * 8) + self.conv4_1 = nn.Conv2D(num_feat * 8, + num_feat * 8, + 4, + 2, + 1, + bias_attr=False) + self.bn4_1 = nn.BatchNorm2D(num_feat * 8) + + self.linear1 = nn.Linear(num_feat * 8 * 4 * 4, 100) + self.linear2 = nn.Linear(100, 1) + + # activation function + self.lrelu = nn.LeakyReLU(negative_slope=0.2) + + def forward(self, x): + assert x.shape[2] == 128 and x.shape[3] == 128, ( + f'Input spatial size must be 128x128, ' + f'but received {x.shape}.') + + feat = self.lrelu(self.conv0_0(x)) + feat = self.lrelu(self.bn0_1( + self.conv0_1(feat))) # output spatial size: (64, 64) + + feat = self.lrelu(self.bn1_0(self.conv1_0(feat))) + feat = self.lrelu(self.bn1_1( + self.conv1_1(feat))) # output spatial size: (32, 32) + + feat = self.lrelu(self.bn2_0(self.conv2_0(feat))) + feat = self.lrelu(self.bn2_1( + self.conv2_1(feat))) # output spatial size: (16, 16) + + feat = self.lrelu(self.bn3_0(self.conv3_0(feat))) + feat = self.lrelu(self.bn3_1( + self.conv3_1(feat))) # output spatial size: (8, 8) + + feat = self.lrelu(self.bn4_0(self.conv4_0(feat))) + feat = self.lrelu(self.bn4_1( + self.conv4_1(feat))) # output spatial size: (4, 4) + + feat = feat.reshape([feat.shape[0], -1]) + feat = self.lrelu(self.linear1(feat)) + out = self.linear2(feat) + return out diff --git a/ppgan/models/esrgan_model.py b/ppgan/models/esrgan_model.py new file mode 100644 index 0000000000000000000000000000000000000000..09dc28318358c1fb40d4be486c1aac889002d9d6 --- /dev/null +++ b/ppgan/models/esrgan_model.py @@ -0,0 +1,123 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
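Before moving into the ESRGAN model, a quick smoke test for the `VGGDiscriminator128` defined above (assumes this patch is installed; any input resolution other than 128x128 trips the assert in `forward`):

```python
import paddle
from ppgan.models.discriminators import VGGDiscriminator128

disc = VGGDiscriminator128(in_channels=3, num_feat=64)
# five stride-2 stages: 128 -> 64 -> 32 -> 16 -> 8 -> 4, then two FC layers
score = disc(paddle.rand([2, 3, 128, 128]))
print(score.shape)  # [2, 1], one realism score per image
```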
+ +import paddle + +from .generators.builder import build_generator +from .discriminators.builder import build_discriminator +from .sr_model import BaseSRModel +from .builder import MODELS + +from .criterions import build_criterion + + +@MODELS.register() +class ESRGAN(BaseSRModel): + """ + This class implements the ESRGAN model. + + ESRGAN paper: https://arxiv.org/pdf/1809.00219.pdf + """ + def __init__(self, + generator, + discriminator=None, + pixel_criterion=None, + perceptual_criterion=None, + gan_criterion=None): + """Initialize the ESRGAN class. + + Args: + generator (dict): config of generator. + discriminator (dict): config of discriminator. + pixel_criterion (dict): config of pixel criterion. + perceptual_criterion (dict): config of perceptual criterion. + gan_criterion (dict): config of gan criterion. + """ + super(ESRGAN, self).__init__(generator) + + self.nets['generator'] = build_generator(generator) + + if discriminator: + self.nets['discriminator'] = build_discriminator(discriminator) + + if pixel_criterion: + self.pixel_criterion = build_criterion(pixel_criterion) + + if perceptual_criterion: + self.perceptual_criterion = build_criterion(perceptual_criterion) + + if gan_criterion: + self.gan_criterion = build_criterion(gan_criterion) + + def train_iter(self, optimizers=None): + self.set_requires_grad(self.nets['discriminator'], False) + optimizers['optimG'].clear_grad() + l_total = 0 + self.output = self.nets['generator'](self.lq) + self.visual_items['output'] = self.output + # pixel loss + if self.pixel_criterion: + l_pix = self.pixel_criterion(self.output, self.gt) + l_total += l_pix + self.losses['loss_pix'] = l_pix + if self.perceptual_criterion: + l_g_percep, l_g_style = self.perceptual_criterion( + self.output, self.gt) + # l_total += l_pix + if l_g_percep is not None: + l_total += l_g_percep + self.losses['loss_percep'] = l_g_percep + if l_g_style is not None: + l_total += l_g_style + self.losses['loss_style'] = l_g_style + + # gan loss (relativistic gan) + real_d_pred = self.nets['discriminator'](self.gt).detach() + fake_g_pred = self.nets['discriminator'](self.output) + l_g_real = self.gan_criterion(real_d_pred - paddle.mean(fake_g_pred), + False, + is_disc=False) + l_g_fake = self.gan_criterion(fake_g_pred - paddle.mean(real_d_pred), + True, + is_disc=False) + l_g_gan = (l_g_real + l_g_fake) / 2 + + l_total += l_g_gan + self.losses['l_g_gan'] = l_g_gan + + l_total.backward() + optimizers['optimG'].step() + + self.set_requires_grad(self.nets['discriminator'], True) + optimizers['optimD'].clear_grad() + # real + fake_d_pred = self.nets['discriminator'](self.output).detach() + real_d_pred = self.nets['discriminator'](self.gt) + l_d_real = self.gan_criterion( + real_d_pred - paddle.mean(fake_d_pred), True, is_disc=True) * 0.5 + + # fake + fake_d_pred = self.nets['discriminator'](self.output.detach()) + l_d_fake = self.gan_criterion( + fake_d_pred - paddle.mean(real_d_pred.detach()), + False, + is_disc=True) * 0.5 + + (l_d_real + l_d_fake).backward() + optimizers['optimD'].step() + + self.losses['l_d_real'] = l_d_real + self.losses['l_d_fake'] = l_d_fake + self.losses['out_d_real'] = paddle.mean(real_d_pred.detach()) + self.losses['out_d_fake'] = paddle.mean(fake_d_pred.detach()) diff --git a/ppgan/models/gan_model.py b/ppgan/models/gan_model.py index 2981d11ab327978bebf5dd1cadc0f999901fb088..e788d6fc40c05e4bbeb4fc0fedc53e762da8ca92 100644 --- a/ppgan/models/gan_model.py +++ b/ppgan/models/gan_model.py @@ -19,7 +19,7 @@ from .base_model import BaseModel from .builder 
import MODELS from .generators.builder import build_generator from .discriminators.builder import build_discriminator -from .losses import GANLoss +from .criterions.gan_loss import GANLoss from ..solver import build_optimizer from ..modules.init import init_weights @@ -80,9 +80,11 @@ class GANModel(BaseModel): if not isinstance(input, dict): input = {'img': input} self.D_real_inputs = [paddle.to_tensor(input['img'])] - if 'class_id' in input: # n class input + if 'class_id' in input: # n class input self.n_class = self.nets['netG'].n_class - self.D_real_inputs += [paddle.to_tensor(input['class_id'], dtype='int64')] + self.D_real_inputs += [ + paddle.to_tensor(input['class_id'], dtype='int64') + ] else: self.n_class = 0 @@ -97,15 +99,18 @@ class GANModel(BaseModel): rows_num = (batch_size - 1) // self.samples_every_row + 1 class_ids = paddle.randint(0, self.n_class, [rows_num, 1]) class_ids = class_ids.tile([1, self.samples_every_row]) - class_ids = class_ids.reshape([-1,])[:batch_size].detach() + class_ids = class_ids.reshape([ + -1, + ])[:batch_size].detach() self.G_fixed_inputs[1] = class_ids.detach() def forward(self): """Run forward pass; called by both functions and .""" - self.fake_imgs = self.nets['netG'](*self.G_inputs) # G(img, class_id) + self.fake_imgs = self.nets['netG'](*self.G_inputs) # G(img, class_id) # put items to visual dict - self.visual_items['fake_imgs'] = make_grid(self.fake_imgs, self.samples_every_row).detach() + self.visual_items['fake_imgs'] = make_grid( + self.fake_imgs, self.samples_every_row).detach() def backward_D(self): """Calculate GAN loss for the discriminator""" @@ -118,7 +123,8 @@ class GANModel(BaseModel): pred_fake = self.nets['netD'](*self.D_fake_inputs) # Real real_imgs = self.D_real_inputs[0] - self.visual_items['real_imgs'] = make_grid(real_imgs, self.samples_every_row).detach() + self.visual_items['real_imgs'] = make_grid( + real_imgs, self.samples_every_row).detach() pred_real = self.nets['netD'](*self.D_real_inputs) self.loss_D_fake = self.criterionGAN(pred_fake, False, True) @@ -126,7 +132,8 @@ class GANModel(BaseModel): # combine loss and calculate gradients if self.cfg.model.gan_mode in ['vanilla', 'lsgan']: - self.loss_D = self.loss_D + (self.loss_D_fake + self.loss_D_real) * 0.5 + self.loss_D = self.loss_D + (self.loss_D_fake + + self.loss_D_real) * 0.5 else: self.loss_D = self.loss_D + self.loss_D_fake + self.loss_D_real @@ -179,7 +186,7 @@ class GANModel(BaseModel): if self.step % self.visual_interval == 0: with paddle.no_grad(): self.visual_items['fixed_generated_imgs'] = make_grid( - self.nets['netG'](*self.G_fixed_inputs), self.samples_every_row - ) + self.nets['netG'](*self.G_fixed_inputs), + self.samples_every_row) self.step += 1 diff --git a/ppgan/models/makeup_model.py b/ppgan/models/makeup_model.py index 1031fd52adf3a43e23122b15766bc7f18947aa31..947191b053690115968aa768b8a5815076088f81 100644 --- a/ppgan/models/makeup_model.py +++ b/ppgan/models/makeup_model.py @@ -15,8 +15,7 @@ import os import numpy as np import paddle -import paddle.nn as nn -import paddle.nn.functional as F + from paddle.vision.models import vgg16 from paddle.utils.download import get_path_from_url from .base_model import BaseModel @@ -24,12 +23,10 @@ from .base_model import BaseModel from .builder import MODELS from .generators.builder import build_generator from .discriminators.builder import build_discriminator -from .losses import GANLoss +from .criterions import build_criterion from ..modules.init import init_weights -from ..solver import 
build_optimizer from ..utils.image_pool import ImagePool from ..utils.preprocess import * -from ..datasets.makeup_dataset import MakeupDataset VGGFACE_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/vggface.pdparams' @@ -39,18 +36,32 @@ class MakeupModel(BaseModel): """ PSGAN paper: https://arxiv.org/pdf/1909.06956.pdf """ - def __init__(self, cfg): + def __init__(self, + generator, + discriminator=None, + cycle_criterion=None, + idt_criterion=None, + gan_criterion=None, + l1_criterion=None, + l2_criterion=None, + pool_size=50, + direction='a2b', + lambda_a=10., + lambda_b=10., + is_train=True): """Initialize the PSGAN class. Parameters: cfg (dict)-- config of model. """ - super(MakeupModel, self).__init__(cfg) - + super(MakeupModel, self).__init__() + self.lambda_a = lambda_a + self.lambda_b = lambda_b + self.is_train = is_train # define networks (both Generators and discriminators) # The naming is different from those used in the paper. # Code (vs. paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X) - self.nets['netG'] = build_generator(cfg.model.generator) + self.nets['netG'] = build_generator(generator) init_weights(self.nets['netG'], init_type='xavier', init_gain=1.0) if self.is_train: # define discriminators @@ -61,46 +72,33 @@ class MakeupModel(BaseModel): param = paddle.load(vgg_weight_path) vgg.load_dict(param) - self.nets['netD_A'] = build_discriminator(cfg.model.discriminator) - self.nets['netD_B'] = build_discriminator(cfg.model.discriminator) + self.nets['netD_A'] = build_discriminator(discriminator) + self.nets['netD_B'] = build_discriminator(discriminator) init_weights(self.nets['netD_A'], init_type='xavier', init_gain=1.0) init_weights(self.nets['netD_B'], init_type='xavier', init_gain=1.0) - self.fake_A_pool = ImagePool( - cfg.dataset.train.pool_size - ) # create image buffer to store previously generated images - self.fake_B_pool = ImagePool( - cfg.dataset.train.pool_size - ) # create image buffer to store previously generated images + # create image buffer to store previously generated images + self.fake_A_pool = ImagePool(pool_size) + self.fake_B_pool = ImagePool(pool_size) + # define loss functions - self.criterionGAN = GANLoss( - cfg.model.gan_mode) #.to(self.device) # define GAN loss. - self.criterionCycle = paddle.nn.L1Loss() - self.criterionIdt = paddle.nn.L1Loss() - self.criterionL1 = paddle.nn.L1Loss() - self.criterionL2 = paddle.nn.MSELoss() - - self.build_lr_scheduler() - self.optimizers['optimizer_G'] = build_optimizer( - cfg.optimizer, - self.lr_scheduler, - parameter_list=self.nets['netG'].parameters()) - self.optimizers['optimizer_DA'] = build_optimizer( - cfg.optimizer, - self.lr_scheduler, - parameter_list=self.nets['netD_A'].parameters()) - self.optimizers['optimizer_DB'] = build_optimizer( - cfg.optimizer, - self.lr_scheduler, - parameter_list=self.nets['netD_B'].parameters()) - - def set_input(self, input): + if gan_criterion: + self.gan_criterion = build_criterion(gan_criterion) + if cycle_criterion: + self.cycle_criterion = build_criterion(cycle_criterion) + if idt_criterion: + self.idt_criterion = build_criterion(idt_criterion) + if l1_criterion: + self.l1_criterion = build_criterion(l1_criterion) + if l2_criterion: + self.l2_criterion = build_criterion(l2_criterion) + + def setup_input(self, input): """Unpack input data from the dataloader and perform necessary pre-processing steps. - Parameters: + Args: input (dict): include the data itself and its metadata information. - The option 'direction' can be used to swap domain A and domain B. 
""" self.real_A = paddle.to_tensor(input['image_A']) self.real_B = paddle.to_tensor(input['image_B']) @@ -143,24 +141,6 @@ class MakeupModel(BaseModel): self.visual_items['fake_A'] = self.fake_A self.visual_items['rec_B'] = self.rec_B - def forward_test(self, input): - ''' - not implement now - ''' - return self.nets['netG'](input['image_A'], input['image_B'], - input['P_A'], input['P_B'], - input['consis_mask'], input['mask_A_aug'], - input['mask_B_aug']) - - def test(self, input): - """Forward function used in test time. - - This function wraps function in no_grad() so we don't save intermediate steps for backprop - It also calls to produce additional visualization results - """ - with paddle.no_grad(): - return self.forward_test(input) - def backward_D_basic(self, netD, real, fake): """Calculate GAN loss for the discriminator @@ -174,10 +154,10 @@ class MakeupModel(BaseModel): """ # Real pred_real = netD(real) - loss_D_real = self.criterionGAN(pred_real, True) + loss_D_real = self.gan_criterion(pred_real, True) # Fake pred_fake = netD(fake.detach()) - loss_D_fake = self.criterionGAN(pred_fake, False) + loss_D_fake = self.gan_criterion(pred_fake, False) # Combined loss and calculate gradients loss_D = (loss_D_real + loss_D_fake) * 0.5 loss_D.backward() @@ -200,24 +180,24 @@ class MakeupModel(BaseModel): def backward_G(self): """Calculate the loss for generators G_A and G_B""" - lambda_idt = self.cfg.lambda_identity - lambda_A = self.cfg.lambda_A - lambda_B = self.cfg.lambda_B + lambda_A = self.lambda_a + lambda_B = self.lambda_b lambda_vgg = 5e-3 + # Identity loss - if lambda_idt > 0: + if self.idt_criterion: self.idt_A, _ = self.nets['netG'](self.real_A, self.real_A, self.P_A, self.P_A, self.c_m_idt_a, self.mask_A_aug, self.mask_B_aug) # G_A(A) - self.loss_idt_A = self.criterionIdt( - self.idt_A, self.real_A) * lambda_A * lambda_idt + self.loss_idt_A = self.idt_criterion(self.idt_A, + self.real_A) * lambda_A self.idt_B, _ = self.nets['netG'](self.real_B, self.real_B, self.P_B, self.P_B, self.c_m_idt_b, self.mask_A_aug, self.mask_B_aug) # G_A(A) - self.loss_idt_B = self.criterionIdt( - self.idt_B, self.real_B) * lambda_B * lambda_idt + self.loss_idt_B = self.idt_criterion(self.idt_B, + self.real_B) * lambda_B # visual self.visual_items['idt_A'] = self.idt_A @@ -227,17 +207,17 @@ class MakeupModel(BaseModel): self.loss_idt_B = 0 # GAN loss D_A(G_A(A)) - self.loss_G_A = self.criterionGAN(self.nets['netD_A'](self.fake_A), - True) + self.loss_G_A = self.gan_criterion(self.nets['netD_A'](self.fake_A), + True) # GAN loss D_B(G_B(B)) - self.loss_G_B = self.criterionGAN(self.nets['netD_B'](self.fake_B), - True) + self.loss_G_B = self.gan_criterion(self.nets['netD_B'](self.fake_B), + True) # Forward cycle loss || G_B(G_A(A)) - A|| - self.loss_cycle_A = self.criterionCycle(self.rec_A, - self.real_A) * lambda_A + self.loss_cycle_A = self.cycle_criterion(self.rec_A, + self.real_A) * lambda_A # Backward cycle loss || G_A(G_B(B)) - B|| - self.loss_cycle_B = self.criterionCycle(self.rec_B, - self.real_B) * lambda_B + self.loss_cycle_B = self.cycle_criterion(self.rec_B, + self.real_B) * lambda_B self.losses['G_A_adv_loss'] = self.loss_G_A self.losses['G_B_adv_loss'] = self.loss_G_B @@ -270,8 +250,10 @@ class MakeupModel(BaseModel): fake_match_lip_B = fake_match_lip_B.unsqueeze(0) fake_A_lip_masked = fake_A * mask_A_lip fake_B_lip_masked = fake_B * mask_B_lip - g_A_lip_loss_his = self.criterionL1(fake_A_lip_masked, fake_match_lip_A) - g_B_lip_loss_his = self.criterionL1(fake_B_lip_masked, 
fake_match_lip_B) + g_A_lip_loss_his = self.l1_criterion(fake_A_lip_masked, + fake_match_lip_A) + g_B_lip_loss_his = self.l1_criterion(fake_B_lip_masked, + fake_match_lip_B) #skin mask_A_skin = self.mask_A_aug[:, 1].unsqueeze(1) @@ -294,10 +276,10 @@ class MakeupModel(BaseModel): fake_match_skin_B = fake_match_skin_B.unsqueeze(0) fake_A_skin_masked = fake_A * mask_A_skin fake_B_skin_masked = fake_B * mask_B_skin - g_A_skin_loss_his = self.criterionL1(fake_A_skin_masked, - fake_match_skin_A) - g_B_skin_loss_his = self.criterionL1(fake_B_skin_masked, - fake_match_skin_B) + g_A_skin_loss_his = self.l1_criterion(fake_A_skin_masked, + fake_match_skin_A) + g_B_skin_loss_his = self.l1_criterion(fake_B_skin_masked, + fake_match_skin_B) #eye mask_A_eye = self.mask_A_aug[:, 2].unsqueeze(1) @@ -320,8 +302,10 @@ class MakeupModel(BaseModel): fake_match_eye_B = fake_match_eye_B.unsqueeze(0) fake_A_eye_masked = fake_A * mask_A_eye fake_B_eye_masked = fake_B * mask_B_eye - g_A_eye_loss_his = self.criterionL1(fake_A_eye_masked, fake_match_eye_A) - g_B_eye_loss_his = self.criterionL1(fake_B_eye_masked, fake_match_eye_B) + g_A_eye_loss_his = self.l1_criterion(fake_A_eye_masked, + fake_match_eye_A) + g_B_eye_loss_his = self.l1_criterion(fake_B_eye_masked, + fake_match_eye_B) self.loss_G_A_his = (g_A_eye_loss_his + g_A_lip_loss_his + g_A_skin_loss_his * 0.1) * 0.1 @@ -335,14 +319,14 @@ class MakeupModel(BaseModel): vgg_s = self.vgg(self.real_A) vgg_s.stop_gradient = True vgg_fake_A = self.vgg(self.fake_A) - self.loss_A_vgg = self.criterionL2(vgg_fake_A, - vgg_s) * lambda_A * lambda_vgg + self.loss_A_vgg = self.l2_criterion(vgg_fake_A, + vgg_s) * lambda_A * lambda_vgg vgg_r = self.vgg(self.real_B) vgg_r.stop_gradient = True vgg_fake_B = self.vgg(self.fake_B) - self.loss_B_vgg = self.criterionL2(vgg_fake_B, - vgg_r) * lambda_B * lambda_vgg + self.loss_B_vgg = self.l2_criterion(vgg_fake_B, + vgg_r) * lambda_B * lambda_vgg self.loss_rec = (self.loss_cycle_A * 0.2 + self.loss_cycle_B * 0.2 + self.loss_A_vgg + self.loss_B_vgg) * 0.5 @@ -359,15 +343,14 @@ class MakeupModel(BaseModel): (self.mask_A == 10), dtype='float32') + paddle.cast( (self.mask_A == 8), dtype='float32') mask_A_consis = paddle.unsqueeze(paddle.clip(mask_A_consis, 0, 1), 1) - self.loss_G_bg_consis = self.criterionL1( + self.loss_G_bg_consis = self.l1_criterion( self.real_A * mask_A_consis, self.fake_A * mask_A_consis) * 0.1 # combined loss and calculate gradients - self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_rec + self.loss_idt + self.loss_G_A_his + self.loss_G_B_his + self.loss_G_bg_consis self.loss_G.backward() - def optimize_parameters(self): + def train_iter(self, optimizers=None): """Calculate losses, gradients, and update network weights; called in every training iteration""" # forward self.forward() # compute fake images and reconstruction images. 
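All of the lip/skin/eye histogram-matching terms above reduce to the same masked-L1 pattern; distilled here with dummy tensors:

```python
import paddle

l1 = paddle.nn.L1Loss()
fake = paddle.rand([1, 3, 64, 64])
match = paddle.rand([1, 3, 64, 64])  # histogram-matched reference
mask = (paddle.rand([1, 1, 64, 64]) > 0.5).astype('float32')

# penalize color differences only inside the (lip/skin/eye) region
loss_his = l1(fake * mask, match * mask)
```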
@@ -392,5 +375,4 @@ class MakeupModel(BaseModel): self.backward_D_B() # calculate graidents for D_B self.optimizers['optimizer_DB'].minimize( self.loss_D_B) #step() # update D_A and D_B's weights - self.optimizers['optimizer_DB'].clear_gradients( - ) #zero_grad() # set D_A and D_B's gradients to zero + self.optimizers['optimizer_DB'].clear_gradients() diff --git a/ppgan/models/pix2pix_model.py b/ppgan/models/pix2pix_model.py index e44bba0d9df566a8cb8ed0c97ae9eda05df788bf..bfb3d0b849933909fd84a851aa456cc50d45b83e 100644 --- a/ppgan/models/pix2pix_model.py +++ b/ppgan/models/pix2pix_model.py @@ -18,7 +18,7 @@ from .base_model import BaseModel from .builder import MODELS from .generators.builder import build_generator from .discriminators.builder import build_discriminator -from .losses import GANLoss +from .criterions import build_criterion from ..solver import build_optimizer from ..modules.init import init_weights @@ -29,63 +29,57 @@ from ..utils.image_pool import ImagePool class Pix2PixModel(BaseModel): """ This class implements the pix2pix model, for learning a mapping from input images to output images given paired data. - The model training requires 'paired' dataset. - By default, it uses a '--netG unet256' U-Net generator, - a '--netD basic' discriminator (from PatchGAN), - and a vanilla GAN loss (the cross-entropy objective used in the orignal GAN paper). - pix2pix paper: https://arxiv.org/pdf/1611.07004.pdf """ - def __init__(self, cfg): + def __init__(self, + generator, + discriminator=None, + pixel_criterion=None, + gan_criterion=None, + direction='a2b'): """Initialize the pix2pix class. - Parameters: - opt (config dict)-- stores all the experiment flags; needs to be a subclass of Dict + Args: + generator (dict): config of generator. + discriminator (dict): config of discriminator. + pixel_criterion (dict): config of pixel criterion. + gan_criterion (dict): config of gan criterion. """ - super(Pix2PixModel, self).__init__(cfg) + super(Pix2PixModel, self).__init__() + + self.direction = direction # define networks (both generator and discriminator) - self.nets['netG'] = build_generator(cfg.model.generator) + self.nets['netG'] = build_generator(generator) init_weights(self.nets['netG']) # define a discriminator; conditional GANs need to take both input and output images; Therefore, #channels for D is input_nc + output_nc - if self.is_train: - self.nets['netD'] = build_discriminator(cfg.model.discriminator) + if discriminator: + self.nets['netD'] = build_discriminator(discriminator) init_weights(self.nets['netD']) - if self.is_train: - self.losses = {} - # define loss functions - self.criterionGAN = GANLoss(cfg.model.gan_mode) - self.criterionL1 = paddle.nn.L1Loss() - - # build optimizers - self.build_lr_scheduler() - self.optimizers['optimizer_G'] = build_optimizer( - cfg.optimizer, - self.lr_scheduler, - parameter_list=self.nets['netG'].parameters()) - self.optimizers['optimizer_D'] = build_optimizer( - cfg.optimizer, - self.lr_scheduler, - parameter_list=self.nets['netD'].parameters()) - - def set_input(self, input): + if pixel_criterion: + self.pixel_criterion = build_criterion(pixel_criterion) + + if gan_criterion: + self.gan_criterion = build_criterion(gan_criterion) + + def setup_input(self, input): """Unpack input data from the dataloader and perform necessary pre-processing steps. - Parameters: + Args: input (dict): include the data itself and its metadata information. The option 'direction' can be used to swap images in domain A and domain B. 
""" - AtoB = self.cfg.dataset.train.direction == 'AtoB' + AtoB = self.direction == 'AtoB' self.real_A = paddle.fluid.dygraph.to_variable( input['A' if AtoB else 'B']) self.real_B = paddle.fluid.dygraph.to_variable( input['B' if AtoB else 'A']) - self.image_paths = input['A_paths' if AtoB else 'B_paths'] + self.image_paths = input['A_path' if AtoB else 'B_path'] def forward(self): """Run forward pass; called by both functions and .""" @@ -102,11 +96,11 @@ class Pix2PixModel(BaseModel): # use conditional GANs; we need to feed both input and output to the discriminator fake_AB = paddle.concat((self.real_A, self.fake_B), 1) pred_fake = self.nets['netD'](fake_AB.detach()) - self.loss_D_fake = self.criterionGAN(pred_fake, False) + self.loss_D_fake = self.gan_criterion(pred_fake, False) # Real real_AB = paddle.concat((self.real_A, self.real_B), 1) pred_real = self.nets['netD'](real_AB) - self.loss_D_real = self.criterionGAN(pred_real, True) + self.loss_D_real = self.gan_criterion(pred_real, True) # combine loss and calculate gradients self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5 @@ -120,10 +114,9 @@ class Pix2PixModel(BaseModel): # First, G(A) should fake the discriminator fake_AB = paddle.concat((self.real_A, self.fake_B), 1) pred_fake = self.nets['netD'](fake_AB) - self.loss_G_GAN = self.criterionGAN(pred_fake, True) + self.loss_G_GAN = self.gan_criterion(pred_fake, True) # Second, G(A) = B - self.loss_G_L1 = self.criterionL1(self.fake_B, - self.real_B) * self.cfg.lambda_L1 + self.loss_G_L1 = self.pixel_criterion(self.fake_B, self.real_B) # combine loss and calculate gradients self.loss_G = self.loss_G_GAN + self.loss_G_L1 @@ -133,18 +126,18 @@ class Pix2PixModel(BaseModel): self.losses['G_adv_loss'] = self.loss_G_GAN self.losses['G_L1_loss'] = self.loss_G_L1 - def optimize_parameters(self): + def train_iter(self, optimizers=None): # compute fake images: G(A) self.forward() # update D self.set_requires_grad(self.nets['netD'], True) - self.optimizers['optimizer_D'].clear_grad() + optimizers['optimD'].clear_grad() self.backward_D() - self.optimizers['optimizer_D'].step() + optimizers['optimD'].step() # update G self.set_requires_grad(self.nets['netD'], False) - self.optimizers['optimizer_G'].clear_grad() + optimizers['optimG'].clear_grad() self.backward_G() - self.optimizers['optimizer_G'].step() + optimizers['optimG'].step() diff --git a/ppgan/models/sr_model.py b/ppgan/models/sr_model.py index 102ed9de790522bf4932c270368119d5e3cbbcae..565dc649f6a49d67e67b0ae6fdc5a25f25cf9d2e 100644 --- a/ppgan/models/sr_model.py +++ b/ppgan/models/sr_model.py @@ -12,75 +12,69 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from collections import OrderedDict import paddle import paddle.nn as nn from .generators.builder import build_generator -from .discriminators.builder import build_discriminator -from ..solver import build_optimizer +from .criterions.builder import build_criterion from .base_model import BaseModel -from .losses import GANLoss -from .builder import MODELS - -import importlib -from collections import OrderedDict -from copy import deepcopy -from os import path as osp from .builder import MODELS +from ..utils.visual import tensor2img @MODELS.register() -class SRModel(BaseModel): - """Base SR model for single image super-resolution.""" - def __init__(self, cfg): - super(SRModel, self).__init__(cfg) - - self.model_names = ['G'] - - self.netG = build_generator(cfg.model.generator) - self.visual_names = ['lq', 'output', 'gt'] - - self.loss_names = ['l_total'] +class BaseSRModel(BaseModel): + """Base SR model for single image super-resolution. + """ + def __init__(self, generator, pixel_criterion=None): + """ + Args: + generator (dict): config of generator. + pixel_criterion (dict): config of pixel criterion. + """ + super(BaseSRModel, self).__init__() - self.optimizers = [] - if self.is_train: - self.criterionL1 = paddle.nn.L1Loss() + self.nets['generator'] = build_generator(generator) - self.build_lr_scheduler() - self.optimizer_G = build_optimizer( - cfg.optimizer, - self.lr_scheduler, - parameter_list=self.netG.parameters()) - self.optimizers.append(self.optimizer_G) + if pixel_criterion: + self.pixel_criterion = build_criterion(pixel_criterion) - def set_input(self, input): - self.lq = paddle.to_tensor(input['lq']) + def setup_input(self, input): + self.lq = paddle.fluid.dygraph.to_variable(input['lq']) + self.visual_items['lq'] = self.lq if 'gt' in input: - self.gt = paddle.to_tensor(input['gt']) + self.gt = paddle.fluid.dygraph.to_variable(input['gt']) + self.visual_items['gt'] = self.gt self.image_paths = input['lq_path'] def forward(self): pass - def test(self): - """Forward function used in test time. 
- """ - with paddle.no_grad(): - self.output = self.netG(self.lq) - - def optimize_parameters(self): - self.optimizer_G.clear_grad() - self.output = self.netG(self.lq) + def train_iter(self, optims=None): + optims['optim'].clear_grad() - l_total = 0 - loss_dict = OrderedDict() + self.output = self.nets['generator'](self.lq) + self.visual_items['output'] = self.output # pixel loss - if self.criterionL1: - l_pix = self.criterionL1(self.output, self.gt) - l_total += l_pix - loss_dict['l_pix'] = l_pix + loss_pixel = self.pixel_criterion(self.output, self.gt) + self.losses['loss_pixel'] = loss_pixel + + loss_pixel.backward() + optims['optim'].step() - l_total.backward() - self.loss_l_total = l_total - self.optimizer_G.step() + def test_iter(self, metrics=None): + self.nets['generator'].eval() + with paddle.no_grad(): + self.output = self.nets['generator'](self.lq) + self.visual_items['output'] = self.output + self.nets['generator'].train() + + out_img = [] + gt_img = [] + for out_tensor, gt_tensor in zip(self.output, self.gt): + out_img.append(tensor2img(out_tensor, (0., 1.))) + gt_img.append(tensor2img(gt_tensor, (0., 1.))) + + if metrics is not None: + for metric in metrics.values(): + metric.update(out_img, gt_img) diff --git a/ppgan/models/srgan_model.py b/ppgan/models/srgan_model.py deleted file mode 100644 index 0dc8a8820bd020369a7b5b36a5067429ea8fca9d..0000000000000000000000000000000000000000 --- a/ppgan/models/srgan_model.py +++ /dev/null @@ -1,63 +0,0 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from collections import OrderedDict -import paddle -import paddle.nn as nn - -from .generators.builder import build_generator -from .base_model import BaseModel -from .losses import GANLoss -from .builder import MODELS - - -@MODELS.register() -class SRGANModel(BaseModel): - def __init__(self, cfg): - super(SRGANModel, self).__init__(cfg) - - # define networks - self.model_names = ['G'] - - self.netG = build_generator(cfg.model.generator) - self.visual_names = ['LQ', 'GT', 'fake_H'] - - # TODO: support srgan train. - if False: - # self.netD = build_discriminator(cfg.model.discriminator) - self.netG.train() - # self.netD.train() - - def set_input(self, input): - """Unpack input data from the dataloader and perform necessary pre-processing steps. - - Parameters: - input (dict): include the data itself and its metadata information. - - The option 'direction' can be used to swap images in domain A and domain B. 
- """ - - # AtoB = self.opt.dataset.train.direction == 'AtoB' - if 'A' in input: - self.LQ = paddle.to_tensor(input['A']) - if 'B' in input: - self.GT = paddle.to_tensor(input['B']) - if 'A_paths' in input: - self.image_paths = input['A_paths'] - - def forward(self): - self.fake_H = self.netG(self.LQ) - - def optimize_parameters(self, step): - pass diff --git a/ppgan/models/ugatit_model.py b/ppgan/models/ugatit_model.py index 73587e2cbd767856d260e2088b92359c72a20dd1..8488b1ba3f6143866cdc7162cd41945d8e27be35 100644 --- a/ppgan/models/ugatit_model.py +++ b/ppgan/models/ugatit_model.py @@ -19,7 +19,7 @@ from .base_model import BaseModel from .builder import MODELS from .generators.builder import build_generator from .discriminators.builder import build_discriminator -from .losses import GANLoss +from .criterions import build_criterion from ..solver import build_optimizer from ..modules.nn import RhoClipper @@ -34,54 +34,58 @@ class UGATITModel(BaseModel): UGATIT paper: https://arxiv.org/pdf/1907.10830.pdf """ - def __init__(self, cfg): + def __init__(self, + generator, + discriminator_g=None, + discriminator_l=None, + l1_criterion=None, + mse_criterion=None, + bce_criterion=None, + direction='a2b', + adv_weight=1.0, + cycle_weight=10.0, + identity_weight=10.0, + cam_weight=1000.0): """Initialize the CycleGAN class. Parameters: opt (config)-- stores all the experiment flags; needs to be a subclass of Dict """ - super(UGATITModel, self).__init__(cfg) - + super(UGATITModel, self).__init__() + self.adv_weight = adv_weight + self.cycle_weight = cycle_weight + self.identity_weight = identity_weight + self.cam_weight = cam_weight + self.direction = direction # define networks (both Generators and discriminators) # The naming is different from those used in the paper. 
- self.nets['genA2B'] = build_generator(cfg.model.generator) - self.nets['genB2A'] = build_generator(cfg.model.generator) + self.nets['genA2B'] = build_generator(generator) + self.nets['genB2A'] = build_generator(generator) init_weights(self.nets['genA2B']) init_weights(self.nets['genB2A']) - if self.is_train: + if discriminator_g and discriminator_l: # define discriminators - self.nets['disGA'] = build_discriminator(cfg.model.discriminator_g) - self.nets['disGB'] = build_discriminator(cfg.model.discriminator_g) - self.nets['disLA'] = build_discriminator(cfg.model.discriminator_l) - self.nets['disLB'] = build_discriminator(cfg.model.discriminator_l) + self.nets['disGA'] = build_discriminator(discriminator_g) + self.nets['disGB'] = build_discriminator(discriminator_g) + self.nets['disLA'] = build_discriminator(discriminator_l) + self.nets['disLB'] = build_discriminator(discriminator_l) init_weights(self.nets['disGA']) init_weights(self.nets['disGB']) init_weights(self.nets['disLA']) init_weights(self.nets['disLB']) - if self.is_train: - # define loss functions - self.BCE_loss = nn.BCEWithLogitsLoss() - self.L1_loss = nn.L1Loss() - self.MSE_loss = nn.MSELoss() - - self.build_lr_scheduler() - self.optimizers['optimizer_G'] = build_optimizer( - cfg.optimizer, - self.lr_scheduler, - parameter_list=self.nets['genA2B'].parameters() + - self.nets['genB2A'].parameters()) - self.optimizers['optimizer_D'] = build_optimizer( - cfg.optimizer, - self.lr_scheduler, - parameter_list=self.nets['disGA'].parameters() + - self.nets['disGB'].parameters() + - self.nets['disLA'].parameters() + - self.nets['disLB'].parameters()) - self.Rho_clipper = RhoClipper(0, 1) - - def set_input(self, input): + # define loss functions + if l1_criterion: + self.L1_loss = build_criterion(l1_criterion) + if bce_criterion: + self.BCE_loss = build_criterion(bce_criterion) + if mse_criterion: + self.MSE_loss = build_criterion(mse_criterion) + + self.Rho_clipper = RhoClipper(0, 1) + + def setup_input(self, input): """Unpack input data from the dataloader and perform necessary pre-processing steps. Args: @@ -89,8 +93,7 @@ class UGATITModel(BaseModel): The option 'direction' can be used to swap domain A and domain B. """ - mode = 'train' if self.is_train else 'test' - AtoB = self.cfg.dataset[mode].direction == 'AtoB' + AtoB = self.direction == 'a2b' if AtoB: if 'A' in input: @@ -109,7 +112,7 @@ class UGATITModel(BaseModel): self.image_paths = input['B_paths'] def forward(self): - """Run forward pass; called by both functions and .""" + """Run forward pass; called by both functions and .""" if hasattr(self, 'real_A'): self.fake_A2B, _, _ = self.nets['genA2B'](self.real_A) @@ -124,7 +127,7 @@ class UGATITModel(BaseModel): self.visual_items['real_B'] = self.real_B self.visual_items['fake_B2A'] = self.fake_B2A - def test(self): + def test_iter(self, metrics=None): """Forward function used in test time. 
This function wraps function in no_grad() so we don't save intermediate steps for backprop @@ -139,7 +142,7 @@ class UGATITModel(BaseModel): self.nets['genA2B'].train() self.nets['genB2A'].train() - def optimize_parameters(self): + def train_iter(self, optimizers=None): """Calculate losses, gradients, and update network weights; called in every training iteration""" def _criterion(loss_func, logit, is_real): if is_real: @@ -153,7 +156,7 @@ class UGATITModel(BaseModel): self.forward() # update D - self.optimizers['optimizer_D'].clear_grad() + optimizers['optimD'].clear_grad() real_GA_logit, real_GA_cam_logit, _ = self.nets['disGA'](self.real_A) real_LA_logit, real_LA_cam_logit, _ = self.nets['disLA'](self.real_A) real_GB_logit, real_GB_cam_logit, _ = self.nets['disGB'](self.real_B) @@ -196,17 +199,17 @@ class UGATITModel(BaseModel): self.MSE_loss, real_LB_cam_logit, True) + _criterion( self.MSE_loss, fake_LB_cam_logit, False) - D_loss_A = self.cfg.adv_weight * (D_ad_loss_GA + D_ad_cam_loss_GA + - D_ad_loss_LA + D_ad_cam_loss_LA) - D_loss_B = self.cfg.adv_weight * (D_ad_loss_GB + D_ad_cam_loss_GB + - D_ad_loss_LB + D_ad_cam_loss_LB) + D_loss_A = self.adv_weight * (D_ad_loss_GA + D_ad_cam_loss_GA + + D_ad_loss_LA + D_ad_cam_loss_LA) + D_loss_B = self.adv_weight * (D_ad_loss_GB + D_ad_cam_loss_GB + + D_ad_loss_LB + D_ad_cam_loss_LB) Discriminator_loss = D_loss_A + D_loss_B Discriminator_loss.backward() - self.optimizers['optimizer_D'].step() + optimizers['optimD'].step() # update G - self.optimizers['optimizer_G'].clear_grad() + optimizers['optimG'].clear_grad() fake_A2B, fake_A2B_cam_logit, _ = self.nets['genA2B'](self.real_A) fake_B2A, fake_B2A_cam_logit, _ = self.nets['genB2A'](self.real_B) @@ -245,16 +248,16 @@ class UGATITModel(BaseModel): fake_A2B_cam_logit, True) + _criterion( self.BCE_loss, fake_B2B_cam_logit, False) - G_loss_A = self.cfg.adv_weight * ( + G_loss_A = self.adv_weight * ( G_ad_loss_GA + G_ad_cam_loss_GA + G_ad_loss_LA + G_ad_cam_loss_LA - ) + self.cfg.cycle_weight * G_recon_loss_A + self.cfg.identity_weight * G_identity_loss_A + self.cfg.cam_weight * G_cam_loss_A - G_loss_B = self.cfg.adv_weight * ( + ) + self.cycle_weight * G_recon_loss_A + self.identity_weight * G_identity_loss_A + self.cam_weight * G_cam_loss_A + G_loss_B = self.adv_weight * ( G_ad_loss_GB + G_ad_cam_loss_GB + G_ad_loss_LB + G_ad_cam_loss_LB - ) + self.cfg.cycle_weight * G_recon_loss_B + self.cfg.identity_weight * G_identity_loss_B + self.cfg.cam_weight * G_cam_loss_B + ) + self.cycle_weight * G_recon_loss_B + self.identity_weight * G_identity_loss_B + self.cam_weight * G_cam_loss_B Generator_loss = G_loss_A + G_loss_B Generator_loss.backward() - self.optimizers['optimizer_G'].step() + optimizers['optimG'].step() # clip parameter of AdaILN and ILN, applied after optimizer step self.nets['genA2B'].apply(self.Rho_clipper) diff --git a/ppgan/modules/init.py b/ppgan/modules/init.py index 836537a26a7ea5473ed88469e2b1e6e0a0bf8eb8..4a4cc16129149bfc35624c6384c0ee6f808fd3c2 100644 --- a/ppgan/modules/init.py +++ b/ppgan/modules/init.py @@ -19,8 +19,6 @@ import paddle from ..utils.logger import get_logger -logger = get_logger('init') - def _calculate_fan_in_and_fan_out(tensor): dimensions = len(tensor.shape) @@ -313,5 +311,6 @@ def init_weights(net, init_type='normal', init_gain=0.02): normal_(m.weight, 1.0, init_gain) constant_(m.bias, 0.0) + logger = get_logger() logger.debug('initialize network with %s' % init_type) net.apply(init_func) # apply the initialization function diff --git a/ppgan/solver/__init__.py 
b/ppgan/solver/__init__.py index 57021870b340283a36eb3238bb42da598c3181e9..1b4d1fc7b586773978d80c0a397592b0ca7af5de 100644 --- a/ppgan/solver/__init__.py +++ b/ppgan/solver/__init__.py @@ -12,4 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .optimizer import build_optimizer +from .lr_scheduler import CosineAnnealingRestartLR, LinearDecay +from .optimizer import * +from .builder import build_lr_scheduler +from .builder import build_optimizer diff --git a/ppgan/datasets/transforms/makeup_transforms.py b/ppgan/solver/builder.py similarity index 53% rename from ppgan/datasets/transforms/makeup_transforms.py rename to ppgan/solver/builder.py index ba2265a55db6efea7545ef4b0febb1894219a413..854e89bef25d516af802e7675500e7dae7c0eee7 100644 --- a/ppgan/datasets/transforms/makeup_transforms.py +++ b/ppgan/solver/builder.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,18 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. -import paddle.vision.transforms as T -import cv2 +from ..utils.registry import Registry +LRSCHEDULERS = Registry("LRSCHEDULER") +OPTIMIZERS = Registry("OPTIMIZER") -def get_makeup_transform(cfg, pic="image"): - if pic == "image": - transform = T.Compose([ - T.Resize(size=cfg.trans_size), - T.Transpose(), - ]) - else: - transform = T.Resize(size=cfg.trans_size, - interpolation=cv2.INTER_NEAREST) - return transform +def build_lr_scheduler(cfg): + cfg_ = cfg.copy() + name = cfg_.pop('name') + return LRSCHEDULERS.get(name)(**cfg_) + + +def build_optimizer(cfg, lr_scheduler, parameters=None): + cfg_ = cfg.copy() + name = cfg_.pop('name') + return OPTIMIZERS.get(name)(lr_scheduler, parameters=parameters, **cfg_) diff --git a/ppgan/solver/lr_scheduler.py b/ppgan/solver/lr_scheduler.py index e21943dfab43d66b8cd63c388cea2d8e8a553cd8..aa7cc3de1ddeb95434eafd26b1c2bb8c7c8b8e0b 100644 --- a/ppgan/solver/lr_scheduler.py +++ b/ppgan/solver/lr_scheduler.py @@ -12,22 +12,93 @@ # See the License for the specific language governing permissions and # limitations under the License. +import math import paddle +from paddle.optimizer.lr import LRScheduler, MultiStepDecay, LambdaDecay +from .builder import LRSCHEDULERS -def build_lr_scheduler(cfg): - name = cfg.pop('name') +LRSCHEDULERS.register(MultiStepDecay) - # TODO: add more learning rate scheduler - if name == 'linear': +@LRSCHEDULERS.register() +class LinearDecay(LambdaDecay): + def __init__(self, learning_rate, start_epoch, decay_epochs, + iters_per_epoch): def lambda_rule(epoch): - lr_l = 1.0 - max( - 0, epoch + 1 - cfg.start_epoch) / float(cfg.decay_epochs + 1) + epoch = epoch // iters_per_epoch + lr_l = 1.0 - max(0, + epoch + 1 - start_epoch) / float(decay_epochs + 1) return lr_l - scheduler = paddle.optimizer.lr.LambdaDecay(cfg.learning_rate, - lr_lambda=lambda_rule) - return scheduler - else: - raise NotImplementedError + super().__init__(learning_rate, lambda_rule) + + +def get_position_from_periods(iteration, cumulative_period): + """Get the position from a period list. + + It will return the index of the right-closest number in the period list. 
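Putting the two builders above together with `LinearDecay` (a sketch assuming this patch is installed; the trainer normally fills in `iters_per_epoch` from the real dataset, and resolves `net_names` itself before calling `build_optimizer`):

```python
import paddle
from ppgan.solver import build_lr_scheduler, build_optimizer

lr = build_lr_scheduler({'name': 'LinearDecay', 'learning_rate': 0.0002,
                         'start_epoch': 100, 'decay_epochs': 100,
                         'iters_per_epoch': 1})
net = paddle.nn.Linear(4, 4)
optim = build_optimizer({'name': 'Adam', 'beta1': 0.5}, lr,
                        parameters=net.parameters())
```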
diff --git a/ppgan/solver/lr_scheduler.py b/ppgan/solver/lr_scheduler.py
index e21943dfab43d66b8cd63c388cea2d8e8a553cd8..aa7cc3de1ddeb95434eafd26b1c2bb8c7c8b8e0b 100644
--- a/ppgan/solver/lr_scheduler.py
+++ b/ppgan/solver/lr_scheduler.py
@@ -12,22 +12,93 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import math
 import paddle
+from paddle.optimizer.lr import LRScheduler, MultiStepDecay, LambdaDecay
 
+from .builder import LRSCHEDULERS
 
-def build_lr_scheduler(cfg):
-    name = cfg.pop('name')
+LRSCHEDULERS.register(MultiStepDecay)
 
-    # TODO: add more learning rate scheduler
-    if name == 'linear':
+
+@LRSCHEDULERS.register()
+class LinearDecay(LambdaDecay):
+    def __init__(self, learning_rate, start_epoch, decay_epochs,
+                 iters_per_epoch):
         def lambda_rule(epoch):
-            lr_l = 1.0 - max(
-                0, epoch + 1 - cfg.start_epoch) / float(cfg.decay_epochs + 1)
+            epoch = epoch // iters_per_epoch
+            lr_l = 1.0 - max(0,
+                             epoch + 1 - start_epoch) / float(decay_epochs + 1)
             return lr_l
 
-        scheduler = paddle.optimizer.lr.LambdaDecay(cfg.learning_rate,
-                                                    lr_lambda=lambda_rule)
-        return scheduler
-    else:
-        raise NotImplementedError
+        super().__init__(learning_rate, lambda_rule)
+
+
+def get_position_from_periods(iteration, cumulative_period):
+    """Get the position from a period list.
+
+    It will return the index of the right-closest number in the period list.
+    For example, with cumulative_period = [100, 200, 300, 400]:
+    if iteration == 50, return 0;
+    if iteration == 210, return 2;
+    if iteration == 300, return 2.
+
+    Args:
+        iteration (int): Current iteration.
+        cumulative_period (list[int]): Cumulative period list.
+
+    Returns:
+        int: The position of the right-closest number in the period list.
+    """
+    for i, period in enumerate(cumulative_period):
+        if iteration <= period:
+            return i
+
+
+@LRSCHEDULERS.register()
+class CosineAnnealingRestartLR(LRScheduler):
+    """Cosine annealing with restarts learning rate scheme.
+
+    An example config:
+    periods = [10, 10, 10, 10]
+    restart_weights = [1, 0.5, 0.5, 0.5]
+    eta_min = 1e-7
+
+    This has four cycles of 10 iterations each. At the 10th, 20th and 30th
+    iterations, the scheduler restarts with the corresponding weight in
+    restart_weights.
+
+    Args:
+        learning_rate (float): Base learning rate.
+        periods (list): Period for each cosine annealing cycle.
+        restart_weights (list): Restart weights at each restart iteration.
+            Default: [1].
+        eta_min (float): The minimum lr. Default: 0.
+        last_epoch (int): Used in _LRScheduler. Default: -1.
+    """
+    def __init__(self,
+                 learning_rate,
+                 periods,
+                 restart_weights=[1],
+                 eta_min=0,
+                 last_epoch=-1):
+        self.periods = periods
+        self.restart_weights = restart_weights
+        self.eta_min = eta_min
+        assert (len(self.periods) == len(self.restart_weights)
+                ), 'periods and restart_weights should have the same length.'
+        self.cumulative_period = [
+            sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))
+        ]
+        super(CosineAnnealingRestartLR, self).__init__(learning_rate,
+                                                       last_epoch)
+
+    def get_lr(self):
+        idx = get_position_from_periods(self.last_epoch,
+                                        self.cumulative_period)
+        current_weight = self.restart_weights[idx]
+        nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1]
+        current_period = self.periods[idx]
+
+        lr = self.eta_min + current_weight * 0.5 * (
+            self.base_lr - self.eta_min) * (1 + math.cos(math.pi * (
+                (self.last_epoch - nearest_restart) / current_period)))
+        return lr
diff --git a/ppgan/solver/optimizer.py b/ppgan/solver/optimizer.py
index 389b0eb981a4400ff399d3756912a15d80ccb5fd..36345b54cffdcd89e2349ac13a05a6b24f86f1fd 100644
--- a/ppgan/solver/optimizer.py
+++ b/ppgan/solver/optimizer.py
@@ -15,14 +15,9 @@
 import copy
 import paddle
 
-from .lr_scheduler import build_lr_scheduler
+from .builder import OPTIMIZERS
 
-
-def build_optimizer(cfg, lr_scheduler, parameter_list=None):
-    cfg_copy = copy.deepcopy(cfg)
-
-    opt_name = cfg_copy.pop('name')
-
-    return getattr(paddle.optimizer, opt_name)(lr_scheduler,
-                                               parameters=parameter_list,
-                                               **cfg_copy)
+OPTIMIZERS.register(paddle.optimizer.Adam)
+OPTIMIZERS.register(paddle.optimizer.SGD)
+OPTIMIZERS.register(paddle.optimizer.Momentum)
+OPTIMIZERS.register(paddle.optimizer.RMSProp)
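A quick sanity check of CosineAnnealingRestartLR (hypothetical driver code, shown only to illustrate the restart behavior described in the docstring; get_lr() is normally called by the scheduler internally):

from ppgan.solver.lr_scheduler import CosineAnnealingRestartLR

sched = CosineAnnealingRestartLR(learning_rate=2e-4,
                                 periods=[10, 10, 10, 10],
                                 restart_weights=[1, 0.5, 0.5, 0.5],
                                 eta_min=1e-7)
for it in range(40):
    lr = sched.get_lr()  # anneals toward eta_min within each 10-step cycle,
    sched.step()         # then restarts at half the base lr after cycle 0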
diff --git a/ppgan/utils/registry.py b/ppgan/utils/registry.py
index 979572d4850903ae1de965415e8fbf5c3ff3c287..3287854d5ed0b131cae5ab23569fa144faeba943 100644
--- a/ppgan/utils/registry.py
+++ b/ppgan/utils/registry.py
@@ -12,6 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import inspect
+import traceback
+
 
 class Registry(object):
     """
@@ -72,3 +75,53 @@ class Registry(object):
                     name, self._name))
 
         return ret
+
+
+def build_from_config(cfg, registry, default_args=None):
+    """Build an instance from a config dict.
+
+    Args:
+        cfg (dict): Config dict. It should at least contain the key "name".
+        registry (ppgan.utils.Registry): The registry to search the name from.
+        default_args (dict, optional): Default initialization arguments.
+
+    Returns:
+        obj: The constructed instance.
+    """
+    if not isinstance(cfg, dict):
+        raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
+    if 'name' not in cfg:
+        if default_args is None or 'name' not in default_args:
+            raise KeyError(
+                '`cfg` or `default_args` must contain the key "name", '
+                f'but got {cfg}\n{default_args}')
+    if not isinstance(registry, Registry):
+        raise TypeError('registry must be a ppgan.utils.Registry object, '
+                        f'but got {type(registry)}')
+    if not (isinstance(default_args, dict) or default_args is None):
+        raise TypeError('default_args must be a dict or None, '
+                        f'but got {type(default_args)}')
+
+    args = cfg.copy()
+
+    if default_args is not None:
+        for name, value in default_args.items():
+            args.setdefault(name, value)
+
+    cls_name = args.pop('name')
+    if isinstance(cls_name, str):
+        obj_cls = registry.get(cls_name)
+    elif inspect.isclass(cls_name):
+        obj_cls = cls_name
+    else:
+        raise TypeError(
+            f'name must be a str or valid class, but got {type(cls_name)}')
+
+    try:
+        instance = obj_cls(**args)
+    except Exception as e:
+        stack_info = traceback.format_exc()
+        print("Failed to initialize class [{}] with error: "
+              "{} and stack:\n{}".format(cls_name, e, str(stack_info)))
+        raise e
+    return instance
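build_from_config is the generic counterpart of the solver builders. A minimal usage sketch with a made-up MODELS registry and DummyModel class (both invented for illustration):

from ppgan.utils.registry import Registry, build_from_config

MODELS = Registry("MODEL")

@MODELS.register()
class DummyModel:
    def __init__(self, channels=3):
        self.channels = channels

# 'name' selects the registered class; remaining keys become kwargs
model = build_from_config(dict(name='DummyModel', channels=1), MODELS)
assert isinstance(model, DummyModel) and model.channels == 1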
diff --git a/ppgan/utils/visual.py b/ppgan/utils/visual.py
index 49d9c303dfae9a77806091ab2e6fefd39a453df3..dccb7a9d62f8fc42432fe27be451662aa4490ebd 100644
--- a/ppgan/utils/visual.py
+++ b/ppgan/utils/visual.py
@@ -36,8 +36,10 @@ def make_grid(tensor, nrow=8, normalize=False, range=None, scale_each=False):
            images separately rather than the (min, max) over all images. Default: ``False``.
     """
     if not (isinstance(tensor, paddle.Tensor) or
-            (isinstance(tensor, list) and all(isinstance(t, paddle.Tensor) for t in tensor))):
-        raise TypeError('tensor or list of tensors expected, got {}'.format(type(tensor)))
+            (isinstance(tensor, list)
+             and all(isinstance(t, paddle.Tensor) for t in tensor))):
+        raise TypeError('tensor or list of tensors expected, got {}'.format(
+            type(tensor)))
 
     # if list of tensors, convert to a 4D mini-batch Tensor
     if isinstance(tensor, list):
@@ -105,19 +107,20 @@ def tensor2img(input_image, min_max=(-1., 1.), image_num=1, imtype=np.uint8):
-        image_num (int)  -- the convert iamge numbers
+        image_num (int)  -- the number of images to convert
         imtype (type)  -- the desired type of the converted numpy array
     """
-    def processing(im, transpose=True):
-        """"processing one numpy image.
+    def processing(img, transpose=True):
+        """Process one numpy image.
 
         Parameters:
-            im (tensor) -- the input image numpy array
+            img (numpy.ndarray) -- the input image array
         """
-        if im.shape[0] == 1:  # grayscale to RGB
-            im = np.tile(im, (3, 1, 1))
-        im = im.clip(min_max[0], min_max[1])
-        im = (im - min_max[0]) / (min_max[1] - min_max[0])
-        im = im * 255.0  # scaling
-        im = np.transpose(im, (1, 2, 0)) if transpose else im  # tranpose
-        return im
+        if img.shape[0] == 1:  # grayscale to RGB
+            img = np.tile(img, (3, 1, 1))
+        img = img.clip(min_max[0], min_max[1])
+        img = (img - min_max[0]) / (min_max[1] - min_max[0])
+        if imtype == np.uint8:
+            img = img * 255.0  # scaling
+        img = np.transpose(img, (1, 2, 0)) if transpose else img  # transpose
+        return img
 
     if not isinstance(input_image, np.ndarray):
         image_numpy = input_image.numpy()  # convert it into a numpy array
@@ -143,6 +146,7 @@ def tensor2img(input_image, min_max=(-1., 1.), image_num=1, imtype=np.uint8):
     else:  # if it is a numpy array, do nothing
         image_numpy = input_image
 
+    image_numpy = image_numpy.round()
     return image_numpy.astype(imtype)
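Finally, a small usage sketch for the updated tensor2img (fake_B is an invented stand-in for generator output, not data from the patch):

import numpy as np
import paddle
from ppgan.utils.visual import tensor2img

fake_B = paddle.rand([1, 3, 256, 256]) * 2 - 1  # pretend output in [-1, 1]
img = tensor2img(fake_B, min_max=(-1., 1.))     # HWC uint8 array in [0, 255]
assert img.dtype == np.uint8 and img.shape[-1] == 3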