Unverified commit d76c74e4, authored by LielinJiang, committed by GitHub

Add esrgan model and refine codes (#104)

* add esrgan model, refine codes
Parent 05483bee
epochs: 200
output_dir: output_dir

model:
  name: CycleGANModel
@@ -20,60 +17,96 @@ model:
    n_layers: 3
    norm_type: instance
    input_nc: 3
  cycle_criterion:
    name: L1Loss
  idt_criterion:
    name: L1Loss
    loss_weight: 0.5
  gan_criterion:
    name: GANLoss
    gan_mode: lsgan

dataset:
  train:
    name: UnpairedDataset
    dataroot_a: data/cityscapes/trainA
    dataroot_b: data/cityscapes/trainB
    num_workers: 0
    batch_size: 1
    is_train: True
    max_size: inf
    preprocess:
      - name: LoadImageFromFile
        key: A
      - name: LoadImageFromFile
        key: B
      - name: Transforms
        input_keys: [A, B]
        pipeline:
          - name: Resize
            size: [286, 286]
            interpolation: 'bicubic' #cv2.INTER_CUBIC
            keys: ['image', 'image']
          - name: RandomCrop
            size: [256, 256]
            keys: ['image', 'image']
          - name: RandomHorizontalFlip
            prob: 0.5
            keys: ['image', 'image']
          - name: Transpose
            keys: ['image', 'image']
          - name: Normalize
            mean: [127.5, 127.5, 127.5]
            std: [127.5, 127.5, 127.5]
            keys: ['image', 'image']
  test:
    name: UnpairedDataset
    dataroot_a: data/cityscapes/testA
    dataroot_b: data/cityscapes/testB
    num_workers: 0
    batch_size: 1
    max_size: inf
    is_train: False
    preprocess:
      - name: LoadImageFromFile
        key: A
      - name: LoadImageFromFile
        key: B
      - name: Transforms
        input_keys: [A, B]
        pipeline:
          - name: Resize
            size: [256, 256]
            interpolation: 'bicubic' #cv2.INTER_CUBIC
            keys: ['image', 'image']
          - name: Transpose
            keys: ['image', 'image']
          - name: Normalize
            mean: [127.5, 127.5, 127.5]
            std: [127.5, 127.5, 127.5]
            keys: ['image', 'image']

lr_scheduler:
  name: LinearDecay
  learning_rate: 0.0002
  start_epoch: 100
  decay_epochs: 100
  # will get from real dataset
  iters_per_epoch: 1

optimizer:
  optimG:
    name: Adam
    net_names:
      - netG_A
      - netG_B
    beta1: 0.5
  optimD:
    name: Adam
    net_names:
      - netD_A
      - netD_B
    beta1: 0.5

log_config:
  interval: 100
......
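Each block under the new `optimizer` section builds one optimizer over the parameters of the networks listed in `net_names`. A rough sketch of how such a config can be consumed (the helper below is illustrative, not the repo's actual builder):

import itertools

import paddle


def build_optimizers(opt_cfg, lr_scheduler, nets):
    """Build one Adam optimizer per config group.

    `nets` is assumed to map names like 'netG_A' to paddle.nn.Layer objects.
    """
    optimizers = {}
    for opt_name, cfg in opt_cfg.items():
        cfg = dict(cfg)
        cfg.pop('name')  # only 'Adam' is handled in this sketch
        net_names = cfg.pop('net_names')
        # chain the parameter lists of all networks in this group
        parameters = list(
            itertools.chain(*[nets[n].parameters() for n in net_names]))
        optimizers[opt_name] = paddle.optimizer.Adam(
            learning_rate=lr_scheduler,
            beta1=cfg.get('beta1', 0.9),
            weight_decay=cfg.get('weight_decay', 0.0),
            parameters=parameters)
    return optimizers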
epochs: 200
output_dir: output_dir

model:
  name: CycleGANModel
@@ -20,60 +17,96 @@ model:
    n_layers: 3
    norm_type: instance
    input_nc: 3
  cycle_criterion:
    name: L1Loss
  idt_criterion:
    name: L1Loss
    loss_weight: 0.5
  gan_criterion:
    name: GANLoss
    gan_mode: lsgan

dataset:
  train:
    name: UnpairedDataset
    dataroot_a: data/horse2zebra/trainA
    dataroot_b: data/horse2zebra/trainB
    num_workers: 0
    batch_size: 1
    is_train: True
    max_size: inf
    preprocess:
      - name: LoadImageFromFile
        key: A
      - name: LoadImageFromFile
        key: B
      - name: Transforms
        input_keys: [A, B]
        pipeline:
          - name: Resize
            size: [286, 286]
            interpolation: 'bicubic' #cv2.INTER_CUBIC
            keys: ['image', 'image']
          - name: RandomCrop
            size: [256, 256]
            keys: ['image', 'image']
          - name: RandomHorizontalFlip
            prob: 0.5
            keys: ['image', 'image']
          - name: Transpose
            keys: ['image', 'image']
          - name: Normalize
            mean: [127.5, 127.5, 127.5]
            std: [127.5, 127.5, 127.5]
            keys: ['image', 'image']
  test:
    name: UnpairedDataset
    dataroot_a: data/horse2zebra/testA
    dataroot_b: data/horse2zebra/testB
    num_workers: 0
    batch_size: 1
    max_size: inf
    is_train: False
    preprocess:
      - name: LoadImageFromFile
        key: A
      - name: LoadImageFromFile
        key: B
      - name: Transforms
        input_keys: [A, B]
        pipeline:
          - name: Resize
            size: [256, 256]
            interpolation: 'bicubic' #cv2.INTER_CUBIC
            keys: ['image', 'image']
          - name: Transpose
            keys: ['image', 'image']
          - name: Normalize
            mean: [127.5, 127.5, 127.5]
            std: [127.5, 127.5, 127.5]
            keys: ['image', 'image']

lr_scheduler:
  name: LinearDecay
  learning_rate: 0.0002
  start_epoch: 100
  decay_epochs: 100
  # will get from real dataset
  iters_per_epoch: 1

optimizer:
  optimG:
    name: Adam
    net_names:
      - netG_A
      - netG_B
    beta1: 0.5
  optimD:
    name: Adam
    net_names:
      - netD_A
      - netD_B
    beta1: 0.5

log_config:
  interval: 100
......
@@ -15,52 +15,70 @@ model:
    norm_type: batch
    ndf: 64
    input_nc: 1
  gan_criterion:
    name: GANLoss
    gan_mode: vanilla

dataset:
  train:
    name: SingleDataset
    dataroot: data/mnist/train
    batch_size: 128
    preprocess:
      - name: LoadImageFromFile
        key: A
      - name: Transforms
        input_keys: [A]
        pipeline:
          - name: Resize
            size: [64, 64]
            interpolation: 'bicubic' #cv2.INTER_CUBIC
            keys: [image, image]
          - name: Transpose
            keys: [image, image]
          - name: Normalize
            mean: [127.5, 127.5, 127.5]
            std: [127.5, 127.5, 127.5]
            keys: [image, image]
  test:
    name: SingleDataset
    dataroot: data/mnist/test
    preprocess:
      - name: LoadImageFromFile
        key: A
      - name: Transforms
        input_keys: [A]
        pipeline:
          - name: Resize
            size: [64, 64]
            interpolation: 'bicubic' #cv2.INTER_CUBIC
            keys: [image, image]
          - name: Transpose
            keys: [image, image]
          - name: Normalize
            mean: [127.5, 127.5, 127.5]
            std: [127.5, 127.5, 127.5]
            keys: [image, image]

lr_scheduler:
  name: LinearDecay
  learning_rate: 0.0002
  start_epoch: 100
  decay_epochs: 100
  # will get from real dataset
  iters_per_epoch: 1

optimizer:
  optimizer_G:
    name: Adam
    net_names:
      - netG
    beta1: 0.5
  optimizer_D:
    name: Adam
    net_names:
      - netD
    beta1: 0.5

log_config:
  interval: 100
......
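The `LinearDecay` scheduler configured above keeps the base learning rate flat for `start_epoch` epochs and then decays it linearly to zero over `decay_epochs`; `iters_per_epoch` is filled in from the real dataset at run time. A small sketch of that rule (my reading of the config, not the repo's exact class):

def linear_decay_lr(epoch, learning_rate=0.0002, start_epoch=100,
                    decay_epochs=100):
    """Piecewise-linear schedule: flat for start_epoch, then linear to 0."""
    if epoch < start_epoch:
        return learning_rate
    # fraction of the decay phase already completed
    progress = min((epoch - start_epoch) / float(decay_epochs), 1.0)
    return learning_rate * (1.0 - progress)

# epoch 100 -> 2e-4, epoch 150 -> 1e-4, epoch 200 -> 0.0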
total_iters: 1000000
output_dir: output_dir
# tensor range for function tensor2img
min_max:
  (0., 1.)

model:
  name: BaseSRModel
  generator:
    name: RRDBNet
    in_nc: 3
    out_nc: 3
    nf: 64
    nb: 23
  pixel_criterion:
    name: L1Loss

dataset:
  train:
    name: SRDataset
    gt_folder: data/DIV2K/DIV2K_train_HR_sub
    lq_folder: data/DIV2K/DIV2K_train_LR_bicubic/X4_sub
    num_workers: 4
    batch_size: 16
    scale: 4
    preprocess:
      - name: LoadImageFromFile
        key: lq
      - name: LoadImageFromFile
        key: gt
      - name: Transforms
        input_keys: [lq, gt]
        pipeline:
          - name: SRPairedRandomCrop
            gt_patch_size: 128
            scale: 4
            keys: [image, image]
          - name: PairedRandomHorizontalFlip
            keys: [image, image]
          - name: PairedRandomVerticalFlip
            keys: [image, image]
          - name: PairedRandomTransposeHW
            keys: [image, image]
          - name: Transpose
            keys: [image, image]
          - name: Normalize
            mean: [0., 0., 0.]
            std: [255., 255., 255.]
            keys: [image, image]
  test:
    name: SRDataset
    gt_folder: data/DIV2K/val_set14/Set14
    lq_folder: data/DIV2K/val_set14/Set14_bicLRx4
    scale: 4
    preprocess:
      - name: LoadImageFromFile
        key: lq
      - name: LoadImageFromFile
        key: gt
      - name: Transforms
        input_keys: [lq, gt]
        pipeline:
          - name: Transpose
            keys: [image, image]
          - name: Normalize
            mean: [0., 0., 0.]
            std: [255., 255., 255.]
            keys: [image, image]

lr_scheduler:
  name: CosineAnnealingRestartLR
  learning_rate: 0.0002
  periods: [250000, 250000, 250000, 250000]
  restart_weights: [1, 1, 1, 1]
  eta_min: !!float 1e-7

optimizer:
  name: Adam
  # add parameters of net_name to optim
  # name should be in self.nets
  net_names:
    - generator
  beta1: 0.9
  beta2: 0.99

validate:
  interval: 5000
  save_img: false
  metrics:
    psnr: # metric name, can be arbitrary
      name: PSNR
      crop_border: 4
      test_y_channel: false
    ssim:
      name: SSIM
      crop_border: 4
      test_y_channel: false

log_config:
  interval: 10
  visiual_interval: 500

snapshot_config:
  interval: 5000
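The `min_max` entry documents the value range that `tensor2img` assumes when converting network output back to a displayable image. A hedged sketch of that conversion (simplified; the repo's own helper may differ in details):

import numpy as np


def tensor2img(tensor, min_max=(0., 1.)):
    """Clamp a CHW float tensor to min_max and rescale to uint8 HWC."""
    img = tensor.numpy().clip(min_max[0], min_max[1])
    img = (img - min_max[0]) / (min_max[1] - min_max[0])
    img = (img.transpose(1, 2, 0) * 255.0).round().astype(np.uint8)
    return img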
total_iters: 250000
output_dir: output_dir
# tensor range for function tensor2img
min_max:
  (0., 1.)

model:
  name: ESRGAN
  generator:
    name: RRDBNet
    in_nc: 3
    out_nc: 3
    nf: 64
    nb: 23
  discriminator:
    name: VGGDiscriminator128
    in_channels: 3
    num_feat: 64
  pixel_criterion:
    name: L1Loss
    loss_weight: !!float 1e-2
  perceptual_criterion:
    name: PerceptualLoss
    layer_weights:
      '34': 1.0
    perceptual_weight: 1.0
    style_weight: 0.0
    norm_img: False
  gan_criterion:
    name: GANLoss
    gan_mode: vanilla
    loss_weight: !!float 5e-3

dataset:
  train:
    name: SRDataset
    gt_folder: data/DIV2K/DIV2K_train_HR_sub
    lq_folder: data/DIV2K/DIV2K_train_LR_bicubic/X4_sub
    num_workers: 6
    batch_size: 32
    scale: 4
    preprocess:
      - name: LoadImageFromFile
        key: lq
      - name: LoadImageFromFile
        key: gt
      - name: Transforms
        input_keys: [lq, gt]
        pipeline:
          - name: SRPairedRandomCrop
            gt_patch_size: 128
            scale: 4
            keys: [image, image]
          - name: PairedRandomHorizontalFlip
            keys: [image, image]
          - name: PairedRandomVerticalFlip
            keys: [image, image]
          - name: PairedRandomTransposeHW
            keys: [image, image]
          - name: Transpose
            keys: [image, image]
          - name: Normalize
            mean: [0., 0., 0.]
            std: [255., 255., 255.]
            keys: [image, image]
  test:
    name: SRDataset
    gt_folder: data/DIV2K/val_set14/Set14
    lq_folder: data/DIV2K/val_set14/Set14_bicLRx4
    scale: 4
    preprocess:
      - name: LoadImageFromFile
        key: lq
      - name: LoadImageFromFile
        key: gt
      - name: Transforms
        input_keys: [lq, gt]
        pipeline:
          - name: Transpose
            keys: [image, image]
          - name: Normalize
            mean: [0., 0., 0.]
            std: [255., 255., 255.]
            keys: [image, image]

lr_scheduler:
  name: MultiStepDecay
  learning_rate: 0.0001
  milestones: [50000, 100000, 200000, 300000]
  gamma: 0.5

optimizer:
  optimG:
    name: Adam
    net_names:
      - generator
    weight_decay: 0.0
    beta1: 0.9
    beta2: 0.99
  optimD:
    name: Adam
    net_names:
      - discriminator
    weight_decay: 0.0
    beta1: 0.9
    beta2: 0.99

validate:
  interval: 5000
  save_img: false
  metrics:
    psnr: # metric name, can be arbitrary
      name: PSNR
      crop_border: 4
      test_y_channel: false
    ssim:
      name: SSIM
      crop_border: 4
      test_y_channel: false

log_config:
  interval: 100
  visiual_interval: 500

snapshot_config:
  interval: 5000
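The three criteria are weighted very unevenly (pixel 1e-2, perceptual 1.0, adversarial 5e-3), which is the standard ESRGAN recipe: the VGG perceptual term (layer '34' of VGG19) dominates. Whether each weight is applied inside its criterion or at summation is an implementation detail; this sketch applies them at summation:

def esrgan_generator_loss(l_pix, l_percep, l_gan,
                          pixel_weight=1e-2,
                          perceptual_weight=1.0,
                          gan_weight=5e-3):
    """Weighted sum of the ESRGAN generator losses, mirroring the config."""
    return (pixel_weight * l_pix + perceptual_weight * l_percep +
            gan_weight * l_gan)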
epochs: 100
output_dir: tmp
checkpoints_dir: checkpoints

model:
  name: MakeupModel
@@ -17,6 +14,17 @@ model:
    n_layers: 3
    input_nc: 3
    norm_type: spectral
  cycle_criterion:
    name: L1Loss
  idt_criterion:
    name: L1Loss
    loss_weight: 0.5
  l1_criterion:
    name: L1Loss
  l2_criterion:
    name: MSELoss
  gan_criterion:
    name: GANLoss
    gan_mode: lsgan

dataset:
@@ -26,28 +34,42 @@ dataset:
    dataroot: data/MT-Dataset
    cls_list: [non-makeup, makeup]
    phase: train
  test:
    name: MakeupDataset
    trans_size: 256
    dataroot: data/MT-Dataset
    cls_list: [non-makeup, makeup]
    phase: test

lr_scheduler:
  name: LinearDecay
  learning_rate: 0.0002
  start_epoch: 100
  decay_epochs: 100
  # will get from real dataset
  iters_per_epoch: 1

optimizer:
  optimizer_G:
    name: Adam
    net_names:
      - netG
    beta1: 0.5
  optimizer_DA:
    name: Adam
    net_names:
      - netD_A
    beta1: 0.5
  optimizer_DB:
    name: Adam
    net_names:
      - netD_B
    beta1: 0.5

log_config:
  interval: 10
  visiual_interval: 500

snapshot_config:
  interval: 5
epochs: 200
output_dir: output_dir

model:
  name: Pix2PixModel
@@ -18,22 +17,29 @@ model:
    n_layers: 3
    input_nc: 6
    norm_type: batch
  direction: b2a
  pixel_criterion:
    name: L1Loss
    loss_weight: 100
  gan_criterion:
    name: GANLoss
    gan_mode: vanilla

dataset:
  train:
    name: PairedDataset
    dataroot: data/cityscapes/train
    num_workers: 4
    batch_size: 1
    preprocess:
      - name: LoadImageFromFile
        key: pair
      - name: SplitPairedImage
        key: pair
        paired_keys: [A, B]
      - name: Transforms
        input_keys: [A, B]
        pipeline:
          - name: Resize
            size: [286, 286]
            interpolation: 'bicubic' #cv2.INTER_CUBIC
@@ -52,15 +58,18 @@ dataset:
            keys: [image, image]
  test:
    name: PairedDataset
    dataroot: data/cityscapes/test
    num_workers: 4
    batch_size: 1
    preprocess:
      - name: LoadImageFromFile
        key: pair
      - name: SplitPairedImage
        key: pair
        paired_keys: [A, B]
      - name: Transforms
        input_keys: [A, B]
        pipeline:
          - name: Resize
            size: [256, 256]
            interpolation: 'bicubic' #cv2.INTER_CUBIC
@@ -72,16 +81,25 @@ dataset:
            std: [127.5, 127.5, 127.5]
            keys: [image, image]

lr_scheduler:
  name: LinearDecay
  learning_rate: 0.0002
  start_epoch: 100
  decay_epochs: 100
  # will get from real dataset
  iters_per_epoch: 1

optimizer:
  optimG:
    name: Adam
    net_names:
      - netG
    beta1: 0.5
  optimD:
    name: Adam
    net_names:
      - netD
    beta1: 0.5

log_config:
  interval: 100
......
epochs: 200
output_dir: output_dir

model:
  name: Pix2PixModel
@@ -18,22 +17,29 @@ model:
    n_layers: 3
    input_nc: 6
    norm_type: batch
  direction: b2a
  pixel_criterion:
    name: L1Loss
    loss_weight: 100
  gan_criterion:
    name: GANLoss
    gan_mode: vanilla

dataset:
  train:
    name: PairedDataset
    dataroot: data/cityscapes/train
    num_workers: 4
    batch_size: 1
    preprocess:
      - name: LoadImageFromFile
        key: pair
      - name: SplitPairedImage
        key: pair
        paired_keys: [A, B]
      - name: Transforms
        input_keys: [A, B]
        pipeline:
          - name: Resize
            size: [286, 286]
            interpolation: 'bicubic' #cv2.INTER_CUBIC
@@ -52,15 +58,15 @@ dataset:
            keys: [image, image]
  test:
    name: PairedDataset
    dataroot: data/cityscapes/test
    num_workers: 4
    batch_size: 1
    preprocess:
      - name: LoadImageFromFile
        key: pair
      - name: SplitPairedImage
        key: pair
        paired_keys: [A, B]
      - name: Transforms
        input_keys: [A, B]
        pipeline:
          - name: Resize
            size: [256, 256]
            interpolation: 'bicubic' #cv2.INTER_CUBIC
@@ -72,15 +78,25 @@ dataset:
            std: [127.5, 127.5, 127.5]
            keys: [image, image]

lr_scheduler:
  name: LinearDecay
  learning_rate: 0.0004
  start_epoch: 100
  decay_epochs: 100
  # will get from real dataset
  iters_per_epoch: 1

optimizer:
  optimG:
    name: Adam
    net_names:
      - netG
    beta1: 0.5
  optimD:
    name: Adam
    net_names:
      - netD
    beta1: 0.5

log_config:
  interval: 100
......
epochs: 200
output_dir: output_dir

model:
  name: Pix2PixModel
@@ -18,22 +17,29 @@ model:
    n_layers: 3
    input_nc: 6
    norm_type: batch
  direction: b2a
  pixel_criterion:
    name: L1Loss
    loss_weight: 100
  gan_criterion:
    name: GANLoss
    gan_mode: vanilla

dataset:
  train:
    name: PairedDataset
    dataroot: data/facades/train
    num_workers: 4
    batch_size: 1
    preprocess:
      - name: LoadImageFromFile
        key: pair
      - name: SplitPairedImage
        key: pair
        paired_keys: [A, B]
      - name: Transforms
        input_keys: [A, B]
        pipeline:
          - name: Resize
            size: [286, 286]
            interpolation: 'bicubic' #cv2.INTER_CUBIC
@@ -52,15 +58,15 @@ dataset:
            keys: [image, image]
  test:
    name: PairedDataset
    dataroot: data/facades/test
    num_workers: 4
    batch_size: 1
    preprocess:
      - name: LoadImageFromFile
        key: pair
      - name: SplitPairedImage
        key: pair
        paired_keys: [A, B]
      - name: Transforms
        input_keys: [A, B]
        pipeline:
          - name: Resize
            size: [256, 256]
            interpolation: 'bicubic' #cv2.INTER_CUBIC
@@ -72,15 +78,25 @@ dataset:
            std: [127.5, 127.5, 127.5]
            keys: [image, image]

lr_scheduler:
  name: LinearDecay
  learning_rate: 0.0002
  start_epoch: 100
  decay_epochs: 100
  # will get from real dataset
  iters_per_epoch: 1

optimizer:
  optimG:
    name: Adam
    net_names:
      - netG
    beta1: 0.5
  optimD:
    name: Adam
    net_names:
      - netD
    beta1: 0.5

log_config:
  interval: 100
......
epochs: 300
output_dir: output_dir

model:
  name: UGATITModel
@@ -25,57 +21,102 @@ model:
    input_nc: 3
    ndf: 64
    n_layers: 5
  l1_criterion:
    name: L1Loss
  mse_criterion:
    name: MSELoss
  bce_criterion:
    name: BCEWithLogitsLoss
  adv_weight: 1.0
  cycle_weight: 10.0
  identity_weight: 10.0
  cam_weight: 1000.0

dataset:
  train:
    name: UnpairedDataset
    dataroot_a: data/selfie2anime/trainA
    dataroot_b: data/selfie2anime/trainB
    num_workers: 0
    batch_size: 1
    is_train: True
    max_size: inf
    preprocess:
      - name: LoadImageFromFile
        key: A
      - name: LoadImageFromFile
        key: B
      - name: Transforms
        input_keys: [A, B]
        pipeline:
          - name: Resize
            size: [286, 286]
            interpolation: 'bicubic' #cv2.INTER_CUBIC
            keys: ['image', 'image']
          - name: RandomCrop
            size: [256, 256]
            keys: ['image', 'image']
          - name: RandomHorizontalFlip
            prob: 0.5
            keys: ['image', 'image']
          - name: Transpose
            keys: ['image', 'image']
          - name: Normalize
            mean: [127.5, 127.5, 127.5]
            std: [127.5, 127.5, 127.5]
            keys: ['image', 'image']
  test:
    name: UnpairedDataset
    dataroot_a: data/selfie2anime/testA
    dataroot_b: data/selfie2anime/testB
    num_workers: 0
    batch_size: 1
    max_size: inf
    is_train: False
    preprocess:
      - name: LoadImageFromFile
        key: A
      - name: LoadImageFromFile
        key: B
      - name: Transforms
        input_keys: [A, B]
        pipeline:
          - name: Resize
            size: [256, 256]
            interpolation: 'bicubic' #cv2.INTER_CUBIC
            keys: ['image', 'image']
          - name: Transpose
            keys: ['image', 'image']
          - name: Normalize
            mean: [127.5, 127.5, 127.5]
            std: [127.5, 127.5, 127.5]
            keys: ['image', 'image']

lr_scheduler:
  name: LinearDecay
  learning_rate: 0.0001
  start_epoch: 150
  decay_epochs: 150
  # will get from real dataset
  iters_per_epoch: 1

optimizer:
  optimG:
    name: Adam
    net_names:
      - genA2B
      - genB2A
    weight_decay: 0.0001
    beta1: 0.5
  optimD:
    name: Adam
    net_names:
      - disGA
      - disGB
      - disLA
      - disLB
    weight_decay: 0.0001
    beta1: 0.5

log_config:
  interval: 10
......
@@ -15,7 +15,7 @@
from .unpaired_dataset import UnpairedDataset
from .single_dataset import SingleDataset
from .paired_dataset import PairedDataset
from .base_sr_dataset import SRDataset
from .makeup_dataset import MakeupDataset
from .common_vision_dataset import CommonVisionDataset
from .animeganv2_dataset import AnimeGANV2Dataset
@@ -12,105 +12,124 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from pathlib import Path
from abc import ABCMeta, abstractmethod

from paddle.io import Dataset

from .preprocess import build_preprocess

IMG_EXTENSIONS = ('.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm',
                  '.PPM', '.bmp', '.BMP')


def scandir(dir_path, suffix=None, recursive=False):
    """Scan a directory to find the interested files.

    Args:
        dir_path (str | obj:`Path`): Path of the directory.
        suffix (str | tuple(str), optional): File suffix that we are
            interested in. Default: None.
        recursive (bool, optional): If set to True, recursively scan the
            directory. Default: False.

    Returns:
        A generator for all the interested files with relative paths.
    """
    if isinstance(dir_path, (str, Path)):
        dir_path = str(dir_path)
    else:
        raise TypeError('"dir_path" must be a string or Path object')

    if (suffix is not None) and not isinstance(suffix, (str, tuple)):
        raise TypeError('"suffix" must be a string or tuple of strings')

    root = dir_path

    def _scandir(dir_path, suffix, recursive):
        for entry in os.scandir(dir_path):
            if not entry.name.startswith('.') and entry.is_file():
                rel_path = os.path.relpath(entry.path, root)
                if suffix is None:
                    yield rel_path
                elif rel_path.endswith(suffix):
                    yield rel_path
            else:
                if recursive:
                    yield from _scandir(entry.path,
                                        suffix=suffix,
                                        recursive=recursive)
                else:
                    continue

    return _scandir(dir_path, suffix=suffix, recursive=recursive)


class BaseDataset(Dataset, metaclass=ABCMeta):
    """Base class for datasets.

    All datasets should subclass it.
    All subclasses should overwrite ``prepare_data_infos``, which loads
    annotation information and generates the image lists.

    Args:
        preprocess (list[dict]): A sequence of data preprocess configs.
    """
    def __init__(self, preprocess=None):
        super(BaseDataset, self).__init__()
        if preprocess:
            self.preprocess = build_preprocess(preprocess)

    @abstractmethod
    def prepare_data_infos(self):
        """Abstract function for loading annotations.

        All subclasses should overwrite this function and set
        ``self.data_infos`` in it. ``data_infos`` should be a list of dicts:
        [{key_path: file_path}, {key_path: file_path}, ...]
        """
        self.data_infos = None

    @staticmethod
    def scan_folder(path):
        """Obtain a sample path list (including sub-folders) from a given folder.

        Args:
            path (str|pathlib.Path): Folder path.

        Returns:
            list[str]: Sample list obtained from the given folder.
        """
        if isinstance(path, (str, Path)):
            path = str(path)
        else:
            raise TypeError("'path' must be a str or a Path object, "
                            f'but received {type(path)}.')

        samples = list(scandir(path, suffix=IMG_EXTENSIONS, recursive=True))
        samples = [os.path.join(path, v) for v in samples]
        assert samples, '{} has no valid image file.'.format(path)
        return samples

    def __getitem__(self, idx):
        datas = self.data_infos[idx]

        if hasattr(self, 'preprocess') and self.preprocess:
            datas = self.preprocess(datas)

        return datas

    def __len__(self):
        """Length of the dataset.

        Returns:
            int: Length of the dataset.
        """
        return len(self.data_infos)
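Concretely, a subclass only needs to populate `self.data_infos` with per-sample path dicts; indexing, preprocessing, and length all come from the base class. A minimal hypothetical subclass (class name and import paths are illustrative):

from ppgan.datasets.base_dataset import BaseDataset
from ppgan.datasets.builder import DATASETS


@DATASETS.register()
class FolderDataset(BaseDataset):
    """Hypothetical example: every image under dataroot is one sample."""
    def __init__(self, dataroot, preprocess=None):
        super(FolderDataset, self).__init__(preprocess)
        self.dataroot = dataroot
        self.prepare_data_infos()

    def prepare_data_infos(self):
        # scan_folder asserts the folder contains at least one image file
        self.data_infos = [
            dict(image_path=p)
            for p in sorted(self.scan_folder(self.dataroot))
        ]
        return self.data_infos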
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import copy
from pathlib import Path

from .base_dataset import BaseDataset
from .builder import DATASETS


@DATASETS.register()
class SRDataset(BaseDataset):
    """Base super resolution dataset for image restoration."""
    def __init__(self,
                 lq_folder,
                 gt_folder,
                 preprocess,
                 scale,
                 filename_tmpl='{}'):
        super(SRDataset, self).__init__(preprocess)

        self.lq_folder = lq_folder
        self.gt_folder = gt_folder
        self.scale = scale
        self.filename_tmpl = filename_tmpl

        self.prepare_data_infos()

    def prepare_data_infos(self):
        """Load annotations for the SR dataset.

        It loads the LQ and GT image paths from folders.

        Returns:
            list[dict]: LQ and GT path pairs.
        """
        self.data_infos = []
        lq_paths = self.scan_folder(self.lq_folder)
        gt_paths = self.scan_folder(self.gt_folder)
        assert len(lq_paths) == len(gt_paths), (
            f'gt and lq datasets have different number of images: '
            f'{len(lq_paths)}, {len(gt_paths)}.')
        for gt_path in gt_paths:
            basename, ext = os.path.splitext(os.path.basename(gt_path))
            lq_path = os.path.join(self.lq_folder,
                                   (f'{self.filename_tmpl.format(basename)}'
                                    f'{ext}'))
            assert lq_path in lq_paths, f'{lq_path} is not in lq_paths.'
            self.data_infos.append(dict(lq_path=lq_path, gt_path=gt_path))
        return self.data_infos
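With this design the dataset is fully described by its config; instantiating it directly looks like the following (folders as in the ESRGAN configs above, which must exist locally):

preprocess = [
    dict(name='LoadImageFromFile', key='lq'),
    dict(name='LoadImageFromFile', key='gt'),
    dict(name='Transforms',
         input_keys=['lq', 'gt'],
         pipeline=[
             dict(name='SRPairedRandomCrop',
                  gt_patch_size=128,
                  scale=4,
                  keys=['image', 'image']),
             dict(name='Transpose', keys=['image', 'image']),
         ]),
]

dataset = SRDataset(lq_folder='data/DIV2K/DIV2K_train_LR_bicubic/X4_sub',
                    gt_folder='data/DIV2K/DIV2K_train_HR_sub',
                    preprocess=preprocess,
                    scale=4)
sample = dataset[0]  # dict with 'lq', 'gt', 'lq_path', 'gt_path', ...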
@@ -66,14 +66,21 @@ class DictDataset(paddle.io.Dataset):

class DictDataLoader():
    def __init__(self,
                 dataset,
                 batch_size,
                 is_train,
                 num_workers=4,
                 distributed=True):
        self.dataset = DictDataset(dataset)

        place = paddle.CUDAPlace(ParallelEnv().dev_id) \
                    if ParallelEnv().nranks > 1 else paddle.CUDAPlace(0)

        if distributed:
            sampler = DistributedBatchSampler(
                self.dataset,
                batch_size=batch_size,
                shuffle=True if is_train else False,
                drop_last=True if is_train else False)
@@ -82,6 +89,14 @@ class DictDataLoader():
                batch_sampler=sampler,
                places=place,
                num_workers=num_workers)
        else:
            self.dataloader = paddle.io.DataLoader(
                self.dataset,
                batch_size=batch_size,
                shuffle=True if is_train else False,
                drop_last=True if is_train else False,
                places=place,
                num_workers=num_workers)

        self.batch_size = batch_size
@@ -117,12 +132,20 @@ class DictDataLoader():
        return current_items


def build_dataloader(cfg, is_train=True, distributed=True):
    cfg_ = cfg.copy()

    batch_size = cfg_.pop('batch_size', 1)
    num_workers = cfg_.pop('num_workers', 0)

    name = cfg_.pop('name')

    dataset = DATASETS.get(name)(**cfg_)

    dataloader = DictDataLoader(dataset,
                                batch_size,
                                is_train,
                                num_workers,
                                distributed=distributed)

    return dataloader
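`build_dataloader` now pops the loader-level keys (`name`, `batch_size`, `num_workers`) and forwards everything else to the dataset constructor, so a config must contain exactly the dataset's keyword arguments. A usage sketch with a plain dict (real configs arrive as attribute-dicts with the same interface):

cfg = dict(name='SRDataset',
           batch_size=16,
           num_workers=4,
           lq_folder='data/DIV2K/DIV2K_train_LR_bicubic/X4_sub',
           gt_folder='data/DIV2K/DIV2K_train_HR_sub',
           scale=4,
           preprocess=[dict(name='LoadImageFromFile', key='lq'),
                       dict(name='LoadImageFromFile', key='gt')])

dataloader = build_dataloader(cfg, is_train=True, distributed=False)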
@@ -13,35 +13,38 @@
# limitations under the License.
import cv2
import random
import os.path
import numpy as np
from PIL import Image

import paddle
import paddle.vision.transforms as T

from .base_dataset import BaseDataset
from ..utils.preprocess import *
from .builder import DATASETS


@DATASETS.register()
class MakeupDataset(paddle.io.Dataset):
    def __init__(self, dataroot, phase, trans_size, cls_list):
        """Initialize psgan dataset class.

        Args:
            dataroot (str): Directory of dataset.
            phase (str): 'train' or 'test'.
        """
        self.image_path = dataroot
        self.mode = phase
        self.trans_size = trans_size
        self.cls_list = cls_list

        self.transform = self.build_makeup_transform()
        self.norm = T.Normalize([127.5, 127.5, 127.5], [127.5, 127.5, 127.5])
        self.transform_mask = self.build_makeup_transform("mask")

        self.cls_A = self.cls_list[0]
        self.cls_B = self.cls_list[1]

        for cls in self.cls_list:
@@ -72,6 +75,18 @@ class MakeupDataset(BaseDataset):
            getattr(self, cls + "_mask_filenames").append(splits[1])
            getattr(self, cls + "_lmks_filenames").append(splits[2])

    def build_makeup_transform(self, pic="image"):
        if pic == "image":
            transform = T.Compose([
                T.Resize(size=self.trans_size),
                T.Transpose(),
            ])
        else:
            transform = T.Resize(size=self.trans_size,
                                 interpolation=cv2.INTER_NEAREST)

        return transform

    def __getitem__(self, index):
        """Return MANet and MDNet needed params.
......
@@ -12,65 +12,35 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .builder import DATASETS
from .base_dataset import BaseDataset


@DATASETS.register()
class PairedDataset(BaseDataset):
    """A dataset class for paired image datasets.
    """
    def __init__(self, dataroot, preprocess):
        """Initialize this dataset class.

        Args:
            dataroot (str): Directory of dataset.
            preprocess (list[dict]): A sequence of data preprocess config.
        """
        super(PairedDataset, self).__init__(preprocess)
        self.dataroot = dataroot
        self.data_infos = self.prepare_data_infos()

    def prepare_data_infos(self):
        """Load paired image paths.

        Returns:
            list[dict]: List that contains paired image paths.
        """
        data_infos = []
        pair_paths = sorted(self.scan_folder(self.dataroot))
        for pair_path in pair_paths:
            data_infos.append(dict(pair_path=pair_path))

        return data_infos
from .io import LoadImageFromFile
from .transforms import (PairedRandomCrop, PairedRandomHorizontalFlip,
                         PairedRandomVerticalFlip, PairedRandomTransposeHW,
                         SRPairedRandomCrop, SplitPairedImage)
from .builder import build_preprocess
@@ -12,52 +12,53 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import traceback

from ...utils.registry import Registry, build_from_config

LOAD_PIPELINE = Registry("LOAD_PIPELINE")
TRANSFORMS = Registry("TRANSFORM")
PREPROCESS = Registry("PREPROCESS")


class Compose(object):
    """
    Composes several transforms together, for composing a list of transforms
    into a single dataset transform.

    Args:
        functions (list[callable]): List of functions to compose.

    Returns:
        A compose object which is callable; __call__ for this Compose
        object will call each given :attr:`functions` sequentially.
    """
    def __init__(self, functions):
        self.functions = functions

    def __call__(self, datas):
        for func in self.functions:
            try:
                datas = func(datas)
            except Exception as e:
                stack_info = traceback.format_exc()
                print("fail to perform function [{}] with error: "
                      "{} and stack:\n{}".format(func, e, str(stack_info)))
                raise RuntimeError
        return datas


def build_preprocess(cfg):
    preprocess = []

    if not isinstance(cfg, (list, tuple)):
        cfg = [cfg]

    for cfg_ in cfg:
        process = build_from_config(cfg_, PREPROCESS)
        preprocess.append(process)

    preprocess = Compose(preprocess)

    return preprocess
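A quick composition check: each config entry becomes a registered callable and `Compose` threads one shared `datas` dict through the chain (import path assumed, sample path illustrative):

from ppgan.datasets.preprocess.builder import build_preprocess

pipeline = build_preprocess([
    dict(name='LoadImageFromFile', key='pair'),
    dict(name='SplitPairedImage', key='pair', paired_keys=['A', 'B']),
])

datas = pipeline({'pair_path': 'data/facades/train/1.jpg'})
# datas now also holds 'pair', 'A', 'B', 'A_path' and 'B_path'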
import cv2

from .builder import PREPROCESS


@PREPROCESS.register()
class LoadImageFromFile(object):
    """Load an image from file.

    Args:
        key (str): Key in results to find the corresponding path. Default: 'image'.
        flag (int): Loading flag for images. Default: -1.
        to_rgb (bool): Whether to convert the image to 'rgb' format. Default: True.
        backend (str): IO backend where images are stored. Default: None.
        save_original_img (bool): If True, maintain a copy of the image in
            `results` dict with name of `f'ori_{key}'`. Default: False.
        kwargs (dict): Args for file client.
    """
    def __init__(self,
                 key='image',
                 flag=-1,
                 to_rgb=True,
                 save_original_img=False,
                 backend=None,
                 **kwargs):
        self.key = key
        self.flag = flag
        self.to_rgb = to_rgb
        self.backend = backend
        self.save_original_img = save_original_img
        self.kwargs = kwargs

    def __call__(self, results):
        """Call function.

        Args:
            results (dict): A dict containing the necessary information and
                data for augmentation.

        Returns:
            dict: A dict containing the processed data and information.
        """
        filepath = str(results[f'{self.key}_path'])
        #TODO: use file client to manage io backend
        # such as opencv, pil, lmdb
        img = cv2.imread(filepath, self.flag)
        if self.to_rgb:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        results[self.key] = img
        results[f'{self.key}_path'] = filepath
        results[f'{self.key}_ori_shape'] = img.shape
        if self.save_original_img:
            results[f'ori_{self.key}'] = img.copy()

        return results
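The contract is path-in, array-out on the same dict: for a key `k`, the loader reads `results[f'{k}_path']` and writes `results[k]` plus bookkeeping entries. For example (file path illustrative):

loader = LoadImageFromFile(key='gt', flag=-1, to_rgb=True)

results = {'gt_path': 'data/DIV2K/DIV2K_train_HR_sub/0001_s001.png'}
results = loader(results)
# results['gt']           -> HxWxC RGB ndarray
# results['gt_ori_shape'] -> shape of the loaded image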
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import random
import numbers
import collections

import paddle.vision.transforms as T
import paddle.vision.transforms.functional as F

from .builder import TRANSFORMS, PREPROCESS, build_from_config

if sys.version_info < (3, 3):
    Sequence = collections.Sequence
    Iterable = collections.Iterable
else:
    Sequence = collections.abc.Sequence
    Iterable = collections.abc.Iterable

TRANSFORMS.register(T.Resize)
TRANSFORMS.register(T.RandomCrop)
TRANSFORMS.register(T.RandomHorizontalFlip)
TRANSFORMS.register(T.RandomVerticalFlip)
TRANSFORMS.register(T.Normalize)
TRANSFORMS.register(T.Transpose)


@PREPROCESS.register()
class Transforms():
    def __init__(self, pipeline, input_keys):
        self.input_keys = input_keys
        self.transforms = []
        for transform_cfg in pipeline:
            self.transforms.append(build_from_config(transform_cfg, TRANSFORMS))

    def __call__(self, datas):
        data = []
        for k in self.input_keys:
            data.append(datas[k])
        data = tuple(data)
        for transform in self.transforms:
            data = transform(data)

            if hasattr(transform, 'params') and isinstance(
                    transform.params, dict):
                datas.update(transform.params)

        for i, k in enumerate(self.input_keys):
            datas[k] = data[i]

        return datas


@PREPROCESS.register()
class SplitPairedImage:
    def __init__(self, key, paired_keys=['A', 'B']):
        self.key = key
        self.paired_keys = paired_keys

    def __call__(self, datas):
        # split an AB image into A and B halves
        h, w = datas[self.key].shape[:2]
        w2 = int(w / 2)

        a, b = self.paired_keys
        datas[a] = datas[self.key][:h, :w2, :]
        datas[b] = datas[self.key][:h, w2:, :]

        datas[a + '_path'] = datas[self.key + '_path']
        datas[b + '_path'] = datas[self.key + '_path']

        return datas


@TRANSFORMS.register()
class PairedRandomCrop(T.RandomCrop):
    def __init__(self, size, keys=None):
        super().__init__(size, keys=keys)

        if isinstance(size, int):
            self.size = (size, size)
        else:
            self.size = size

    def _get_params(self, inputs):
        image = inputs[self.keys.index('image')]
        params = {}
        params['crop_params'] = self._get_param(image, self.size)
        return params

    def _apply_image(self, img):
        i, j, h, w = self.params['crop_params']
        return F.crop(img, i, j, h, w)


@TRANSFORMS.register()
class PairedRandomHorizontalFlip(T.RandomHorizontalFlip):
    def __init__(self, prob=0.5, keys=None):
        super().__init__(prob, keys=keys)

    def _get_params(self, inputs):
        params = {}
        params['flip'] = random.random() < self.prob
        return params

    def _apply_image(self, image):
        if self.params['flip']:
            return F.hflip(image)
        return image


@TRANSFORMS.register()
class PairedRandomVerticalFlip(T.RandomVerticalFlip):
    def __init__(self, prob=0.5, keys=None):
        super().__init__(prob, keys=keys)

    def _get_params(self, inputs):
        params = {}
        params['flip'] = random.random() < self.prob
        return params

    def _apply_image(self, image):
        if self.params['flip']:
            return F.vflip(image)
        return image


@TRANSFORMS.register()
class PairedRandomTransposeHW(T.BaseTransform):
    """Randomly transpose images in H and W dimensions with a probability.

    (TransposeHW = horizontal flip + anti-clockwise rotation by 90 degrees.)
    When used with horizontal/vertical flips, it serves as a way of rotation
    augmentation.
    It also supports randomly transposing a list of images.

    Required keys are the keys in attribute "keys"; added or modified keys are
    "transpose" and the keys in attribute "keys".

    Args:
        prob (float): The probability to transpose the images.
        keys (list[str]): The images to be transposed.
    """
    def __init__(self, prob=0.5, keys=None):
        self.keys = keys
        self.prob = prob

    def _get_params(self, inputs):
        params = {}
        params['transpose'] = random.random() < self.prob
        return params

    def _apply_image(self, image):
        if self.params['transpose']:
            image = image.transpose(1, 0, 2)
        return image


@TRANSFORMS.register()
class SRPairedRandomCrop(T.BaseTransform):
    """Super resolution random crop.

    It crops a pair of lq and gt images with corresponding locations.
    It also supports accepting an lq list and a gt list.

    Required keys are "scale", "lq", and "gt";
    added or modified keys are "lq" and "gt".

    Args:
        scale (int): model upscale factor.
        gt_patch_size (int): cropped gt patch size.
    """
    def __init__(self, scale, gt_patch_size, keys=None):
        self.gt_patch_size = gt_patch_size
        self.scale = scale
        self.keys = keys

    def __call__(self, inputs):
        """inputs must be (lq_img, gt_img)"""
        scale = self.scale
        lq_patch_size = self.gt_patch_size // scale

        lq = inputs[0]
        gt = inputs[1]

        h_lq, w_lq, _ = lq.shape
        h_gt, w_gt, _ = gt.shape

        if h_gt != h_lq * scale or w_gt != w_lq * scale:
            raise ValueError('scale size not match')
        if h_lq < lq_patch_size or w_lq < lq_patch_size:
            raise ValueError('lq size error')

        # randomly choose top and left coordinates for lq patch
        top = random.randint(0, h_lq - lq_patch_size)
        left = random.randint(0, w_lq - lq_patch_size)
        # crop lq patch
        lq = lq[top:top + lq_patch_size, left:left + lq_patch_size, ...]
        # crop corresponding gt patch
        top_gt, left_gt = int(top * scale), int(left * scale)
        gt = gt[top_gt:top_gt + self.gt_patch_size,
                left_gt:left_gt + self.gt_patch_size, ...]

        outputs = (lq, gt)
        return outputs
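With the ESRGAN settings (`scale: 4`, `gt_patch_size: 128`) the LQ patch is 128 // 4 = 32 pixels square and the GT crop is taken at exactly 4x the LQ coordinates. A small worked check:

import numpy as np

crop = SRPairedRandomCrop(scale=4, gt_patch_size=128)

lq = np.zeros((120, 160, 3), dtype=np.float32)  # low-quality input
gt = np.zeros((480, 640, 3), dtype=np.float32)  # 4x ground truth

lq_patch, gt_patch = crop((lq, gt))
assert lq_patch.shape == (32, 32, 3)   # gt_patch_size // scale
assert gt_patch.shape == (128, 128, 3)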
@@ -12,54 +12,34 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .base_dataset import BaseDataset
from .builder import DATASETS


@DATASETS.register()
class SingleDataset(BaseDataset):
    """
    """
    def __init__(self, dataroot, preprocess):
        """Initialize single dataset class.

        Args:
            dataroot (str): Directory of dataset.
            preprocess (list[dict]): A sequence of data preprocess config.
        """
        super(SingleDataset, self).__init__(preprocess)
        self.dataroot = dataroot
        self.data_infos = self.prepare_data_infos()

    def prepare_data_infos(self):
        """Prepare image paths from a folder.

        Returns:
            list[dict]: List that contains image paths.
        """
        data_infos = []
        paths = sorted(self.scan_folder(self.dataroot))
        for path in paths:
            data_infos.append(dict(A_path=path))

        return data_infos
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import cv2
import random
import numpy as np
import paddle.vision.transforms as transform
from pathlib import Path
from paddle.io import Dataset
from .builder import DATASETS
def scandir(dir_path, suffix=None, recursive=False):
"""Scan a directory to find the interested files.
"""
if isinstance(dir_path, (str, Path)):
dir_path = str(dir_path)
else:
raise TypeError('"dir_path" must be a string or Path object')
if (suffix is not None) and not isinstance(suffix, (str, tuple)):
raise TypeError('"suffix" must be a string or tuple of strings')
root = dir_path
def _scandir(dir_path, suffix, recursive):
for entry in os.scandir(dir_path):
if not entry.name.startswith('.') and entry.is_file():
rel_path = os.path.relpath(entry.path, root)
if suffix is None:
yield rel_path
elif rel_path.endswith(suffix):
yield rel_path
else:
if recursive:
yield from _scandir(entry.path,
suffix=suffix,
recursive=recursive)
else:
continue
return _scandir(dir_path, suffix=suffix, recursive=recursive)
def paired_paths_from_folder(folders, keys, filename_tmpl):
"""Generate paired paths from folders.
"""
assert len(folders) == 2, (
'The len of folders should be 2 with [input_folder, gt_folder]. '
f'But got {len(folders)}')
assert len(keys) == 2, (
'The len of keys should be 2 with [input_key, gt_key]. '
f'But got {len(keys)}')
input_folder, gt_folder = folders
input_key, gt_key = keys
input_paths = list(scandir(input_folder))
gt_paths = list(scandir(gt_folder))
assert len(input_paths) == len(gt_paths), (
f'{input_key} and {gt_key} datasets have different number of images: '
f'{len(input_paths)}, {len(gt_paths)}.')
paths = []
for gt_path in gt_paths:
basename, ext = os.path.splitext(os.path.basename(gt_path))
input_name = f'{filename_tmpl.format(basename)}{ext}'
input_path = os.path.join(input_folder, input_name)
assert input_name in input_paths, (f'{input_name} is not in '
f'{input_key}_paths.')
gt_path = os.path.join(gt_folder, gt_path)
paths.append(
dict([(f'{input_key}_path', input_path),
(f'{gt_key}_path', gt_path)]))
return paths
def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path):
"""Paired random crop.
It crops lists of lq and gt images with corresponding locations.
Args:
img_gts (list[ndarray] | ndarray): GT images. Note that all images
should have the same shape. If the input is an ndarray, it will
be transformed to a list containing itself.
img_lqs (list[ndarray] | ndarray): LQ images. Note that all images
should have the same shape. If the input is an ndarray, it will
be transformed to a list containing itself.
gt_patch_size (int): GT patch size.
scale (int): Scale factor.
gt_path (str): Path to ground-truth.
Returns:
list[ndarray] | ndarray: GT images and LQ images. If returned results
only have one element, just return ndarray.
"""
if not isinstance(img_gts, list):
img_gts = [img_gts]
if not isinstance(img_lqs, list):
img_lqs = [img_lqs]
h_lq, w_lq, _ = img_lqs[0].shape
h_gt, w_gt, _ = img_gts[0].shape
lq_patch_size = gt_patch_size // scale
if h_gt != h_lq * scale or w_gt != w_lq * scale:
raise ValueError(
f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x ',
f'multiplication of LQ ({h_lq}, {w_lq}).')
if h_lq < lq_patch_size or w_lq < lq_patch_size:
raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
f'({lq_patch_size}, {lq_patch_size}). '
f'Please remove {gt_path}.')
# randomly choose top and left coordinates for lq patch
top = random.randint(0, h_lq - lq_patch_size)
left = random.randint(0, w_lq - lq_patch_size)
# crop lq patch
img_lqs = [
v[top:top + lq_patch_size, left:left + lq_patch_size, ...]
for v in img_lqs
]
# crop corresponding gt patch
top_gt, left_gt = int(top * scale), int(left * scale)
img_gts = [
v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...]
for v in img_gts
]
if len(img_gts) == 1:
img_gts = img_gts[0]
if len(img_lqs) == 1:
img_lqs = img_lqs[0]
return img_gts, img_lqs
def augment(imgs, hflip=True, rotation=True, flows=None):
"""Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees).
"""
hflip = hflip and random.random() < 0.5
vflip = rotation and random.random() < 0.5
rot90 = rotation and random.random() < 0.5
def _augment(img):
if hflip:
cv2.flip(img, 1, img)
if vflip:
cv2.flip(img, 0, img)
if rot90:
img = img.transpose(1, 0, 2)
return img
def _augment_flow(flow):
if hflip:
cv2.flip(flow, 1, flow)
flow[:, :, 0] *= -1
if vflip:
cv2.flip(flow, 0, flow)
flow[:, :, 1] *= -1
if rot90:
flow = flow.transpose(1, 0, 2)
flow = flow[:, :, [1, 0]]
return flow
if not isinstance(imgs, list):
imgs = [imgs]
imgs = [_augment(img) for img in imgs]
if len(imgs) == 1:
imgs = imgs[0]
if flows is not None:
if not isinstance(flows, list):
flows = [flows]
flows = [_augment_flow(flow) for flow in flows]
if len(flows) == 1:
flows = flows[0]
return imgs, flows
else:
return imgs
@DATASETS.register()
class SRImageDataset(Dataset):
"""Paired image dataset for image restoration."""
def __init__(self, cfg):
super(SRImageDataset, self).__init__()
self.cfg = cfg
self.file_client = None
self.io_backend_opt = cfg['io_backend']
self.gt_folder, self.lq_folder = cfg['dataroot_gt'], cfg['dataroot_lq']
if 'filename_tmpl' in cfg:
self.filename_tmpl = cfg['filename_tmpl']
else:
self.filename_tmpl = '{}'
if self.io_backend_opt['type'] == 'lmdb':
#TODO: LielinJiang support lmdb to accelerate io
pass
elif 'meta_info_file' in self.cfg and self.cfg[
'meta_info_file'] is not None:
#TODO: LielinJiang support lmdb to accelerate io
pass
else:
self.paths = paired_paths_from_folder(
[self.lq_folder, self.gt_folder], ['lq', 'gt'],
self.filename_tmpl)
def __getitem__(self, index):
scale = self.cfg['scale']
# Load gt and lq images. Dimension order: HWC; channel order: BGR;
# image range: [0, 1], float32.
gt_path = self.paths[index]['gt_path']
lq_path = self.paths[index]['lq_path']
img_gt = cv2.imread(gt_path).astype(np.float32) / 255.
img_lq = cv2.imread(lq_path).astype(np.float32) / 255.
# augmentation for training
if self.cfg['phase'] == 'train':
gt_size = self.cfg['gt_size']
# random crop
img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale,
gt_path)
# flip, rotation
img_gt, img_lq = augment([img_gt, img_lq], self.cfg['use_flip'],
self.cfg['use_rot'])
# TODO: color space transform
# BGR to RGB, HWC to CHW, numpy to tensor
permute = transform.Permute()
img_gt = permute(img_gt)
img_lq = permute(img_lq)
return {
'lq': img_lq,
'gt': img_gt,
'lq_path': lq_path,
'gt_path': gt_path
}
def __len__(self):
return len(self.paths)
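# --- Hypothetical config (illustrative, not part of the original file) ---
# All paths and values below are assumptions; 'disk' stands for the plain
# folder branch (anything other than 'lmdb' / 'meta_info_file' falls
# through to paired_paths_from_folder).
_example_cfg = {
    'io_backend': {'type': 'disk'},
    'dataroot_gt': 'data/DIV2K/train_HR',
    'dataroot_lq': 'data/DIV2K/train_LR_x4',
    'scale': 4,
    'gt_size': 128,
    'phase': 'train',
    'use_flip': True,
    'use_rot': True,
}
# dataset = SRImageDataset(_example_cfg)
# sample = dataset[0]  # dict with 'lq', 'gt', 'lq_path', 'gt_path'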
...@@ -24,14 +24,11 @@ class Compose(object): ...@@ -24,14 +24,11 @@ class Compose(object):
""" """
Composes several transforms together for use as a single dataset Composes several transforms together for use as a single dataset
transform. transform.
Args: Args:
transforms (list): List of transforms to compose. transforms (list): List of transforms to compose.
Returns: Returns:
A Compose object which is callable; calling it applies each of the A Compose object which is callable; calling it applies each of the
given :attr:`transforms` sequentially. given :attr:`transforms` sequentially.
""" """
def __init__(self, transforms): def __init__(self, transforms):
self.transforms = transforms self.transforms = transforms
......
...@@ -12,80 +12,68 @@ ...@@ -12,80 +12,68 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import cv2
import random import random
import os.path import os.path
from .base_dataset import BaseDataset, get_transform
from .image_folder import make_dataset
from .base_dataset import BaseDataset
from .builder import DATASETS from .builder import DATASETS
from .transforms.builder import build_transforms
@DATASETS.register() @DATASETS.register()
class UnpairedDataset(BaseDataset): class UnpairedDataset(BaseDataset):
""" """
""" """
def __init__(self, cfg): def __init__(self, dataroot_a, dataroot_b, max_size, is_train, preprocess):
"""Initialize this dataset class. """Initialize unpaired dataset class.
Args: Args:
cfg (dict) -- stores all the experiment flags dataroot_a (str): Directory of dataset a.
""" dataroot_b (str): Directory of dataset b.
BaseDataset.__init__(self, cfg) max_size (int): maximum number of samples to load from each domain.
self.dir_A = os.path.join(cfg.dataroot, cfg.phase + is_train (bool): whether the dataset is used for training.
'A') # create a path '/path/to/data/trainA' preprocess (list[dict]): A sequence of data preprocess config.
self.dir_B = os.path.join(cfg.dataroot, cfg.phase +
'B') # create a path '/path/to/data/trainB'
self.A_paths = sorted(make_dataset(
self.dir_A,
cfg.max_dataset_size)) # load images from '/path/to/data/trainA'
self.B_paths = sorted(make_dataset(
self.dir_B,
cfg.max_dataset_size)) # load images from '/path/to/data/trainB'
self.A_size = len(self.A_paths) # get the size of dataset A
self.B_size = len(self.B_paths) # get the size of dataset B
btoA = self.cfg.direction == 'BtoA'
input_nc = self.cfg.output_nc if btoA else self.cfg.input_nc # get the number of channels of input image
output_nc = self.cfg.input_nc if btoA else self.cfg.output_nc # get the number of channels of output image
self.transform_A = build_transforms(self.cfg.transforms)
self.transform_B = build_transforms(self.cfg.transforms)
self.reset_paths()
def reset_paths(self):
self.path_dict = {}
def __getitem__(self, index):
"""Return a data point and its metadata information.
Parameters:
index (int) -- a random integer for data indexing
Returns a dictionary that contains A, B, A_paths and B_paths
A (tensor) -- an image in the input domain
B (tensor) -- its corresponding image in the target domain
A_paths (str) -- image paths
B_paths (str) -- image paths
""" """
A_path = self.A_paths[ super(UnpairedDataset, self).__init__(preprocess)
index % self.A_size] # make sure index is within the range self.dir_A = os.path.join(dataroot_a)
if self.cfg.serial_batches: # make sure index is within the range self.dir_B = os.path.join(dataroot_b)
index_B = index % self.B_size self.is_train = is_train
else: # randomize the index for domain B to avoid fixed pairs. self.data_infos_a = self.prepare_data_infos(self.dir_A)
index_B = random.randint(0, self.B_size - 1) self.data_infos_b = self.prepare_data_infos(self.dir_B)
B_path = self.B_paths[index_B] self.size_a = len(self.data_infos_a)
self.size_b = len(self.data_infos_b)
def prepare_data_infos(self, dataroot):
"""Load unpaired image paths of one domain.
A_img = cv2.cvtColor(cv2.imread(A_path), cv2.COLOR_BGR2RGB) Args:
B_img = cv2.cvtColor(cv2.imread(B_path), cv2.COLOR_BGR2RGB) dataroot (str): Path to the folder root for unpaired images of
# apply image transformation one domain.
A = self.transform_A(A_img)
B = self.transform_B(B_img)
# return A, B Returns:
return {'A': A, 'B': B, 'A_paths': A_path, 'B_paths': B_path} list[dict]: List that contains unpaired image paths of one domain.
"""
data_infos = []
paths = sorted(self.scan_folder(dataroot))
for path in paths:
data_infos.append(dict(path=path))
return data_infos
def __getitem__(self, idx):
if self.is_train:
img_a_path = self.data_infos_a[idx % self.size_a]['path']
idx_b = random.randint(0, self.size_b - 1)
img_b_path = self.data_infos_b[idx_b]['path']
datas = dict(A_path=img_a_path, B_path=img_b_path)
else:
img_a_path = self.data_infos_a[idx % self.size_a]['path']
img_b_path = self.data_infos_b[idx % self.size_b]['path']
datas = dict(A_path=img_a_path, B_path=img_b_path)
if self.preprocess:
datas = self.preprocess(datas)
return datas
def __len__(self): def __len__(self):
"""Return the total number of images in the dataset. """Return the total number of images in the dataset.
...@@ -93,4 +81,4 @@ class UnpairedDataset(BaseDataset): ...@@ -93,4 +81,4 @@ class UnpairedDataset(BaseDataset):
As the two datasets may contain different numbers of images, As the two datasets may contain different numbers of images,
we take the maximum of the two. we take the maximum of the two.
""" """
return max(self.A_size, self.B_size) return max(self.size_a, self.size_b)
...@@ -27,25 +27,78 @@ from ..models.builder import build_model ...@@ -27,25 +27,78 @@ from ..models.builder import build_model
from ..utils.visual import tensor2img, save_image from ..utils.visual import tensor2img, save_image
from ..utils.filesystem import makedirs, save, load from ..utils.filesystem import makedirs, save, load
from ..utils.timer import TimeAverager from ..utils.timer import TimeAverager
from ..metric.psnr_ssim import calculate_psnr, calculate_ssim
class Trainer: class IterLoader:
def __init__(self, cfg): def __init__(self, dataloader):
self._dataloader = dataloader
self.iter_loader = iter(self._dataloader)
self._epoch = 1
# build train dataloader @property
self.train_dataloader = build_dataloader(cfg.dataset.train) def epoch(self):
return self._epoch
def __next__(self):
try:
data = next(self.iter_loader)
except StopIteration:
self._epoch += 1
self.iter_loader = iter(self._dataloader)
data = next(self.iter_loader)
if 'lr_scheduler' in cfg.optimizer: return data
cfg.optimizer.lr_scheduler.step_per_epoch = len(
self.train_dataloader) def __len__(self):
return len(self._dataloader)
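# --- Usage sketch (illustrative, not part of the original file) ---
# Any re-iterable stands in for a paddle DataLoader here. IterLoader never
# raises StopIteration: it restarts the underlying loader and bumps `epoch`.
if __name__ == '__main__':
    loader = IterLoader([0, 1, 2])   # pretend dataloader with 3 batches
    for _ in range(7):               # 7 steps span 3 passes over the data
        batch = next(loader)
    print(loader.epoch)              # -> 3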
class Trainer:
"""
# trainer calling logic:
#
# build_model || model(BaseModel)
# | ||
# build_dataloader || dataloader
# | ||
# model.setup_lr_schedulers || lr_scheduler
# | ||
# model.setup_optimizers || optimizers
# | ||
# train loop (model.setup_input + model.train_iter) || train loop
# | ||
# print log (model.get_current_losses) ||
# | ||
# save checkpoint (model.nets) \/
"""
def __init__(self, cfg):
# build model # build model
self.model = build_model(cfg) self.model = build_model(cfg.model)
# multiple gpus prepare # multiple gpus prepare
if ParallelEnv().nranks > 1: if ParallelEnv().nranks > 1:
self.distributed_data_parallel() self.distributed_data_parallel()
# build train dataloader
self.train_dataloader = build_dataloader(cfg.dataset.train)
self.iters_per_epoch = len(self.train_dataloader)
# build lr scheduler
# TODO: has a better way?
if 'lr_scheduler' in cfg and 'iters_per_epoch' in cfg.lr_scheduler:
cfg.lr_scheduler.iters_per_epoch = self.iters_per_epoch
self.lr_schedulers = self.model.setup_lr_schedulers(cfg.lr_scheduler)
# build optimizers
self.optimizers = self.model.setup_optimizers(self.lr_schedulers,
cfg.optimizer)
# build metrics
self.metrics = None
validate_cfg = cfg.get('validate', None)
if validate_cfg and 'metrics' in validate_cfg:
self.metrics = self.model.setup_metrics(validate_cfg['metrics'])
self.logger = logging.getLogger(__name__) self.logger = logging.getLogger(__name__)
self.enable_visualdl = cfg.get('enable_visualdl', False) self.enable_visualdl = cfg.get('enable_visualdl', False)
if self.enable_visualdl: if self.enable_visualdl:
...@@ -54,9 +107,18 @@ class Trainer: ...@@ -54,9 +107,18 @@ class Trainer:
# base config # base config
self.output_dir = cfg.output_dir self.output_dir = cfg.output_dir
self.epochs = cfg.epochs self.epochs = cfg.get('epochs', None)
if self.epochs:
self.total_iters = self.epochs * self.iters_per_epoch
self.by_epoch = True
else:
self.by_epoch = False
self.total_iters = cfg.total_iters
self.start_epoch = 1 self.start_epoch = 1
self.current_epoch = 1 self.current_epoch = 1
self.current_iter = 1
self.inner_iter = 1
self.batch_id = 0 self.batch_id = 0
self.global_steps = 0 self.global_steps = 0
self.weight_interval = cfg.snapshot_config.interval self.weight_interval = cfg.snapshot_config.interval
...@@ -69,10 +131,6 @@ class Trainer: ...@@ -69,10 +131,6 @@ class Trainer:
self.local_rank = ParallelEnv().local_rank self.local_rank = ParallelEnv().local_rank
# time count
self.steps_per_epoch = len(self.train_dataloader)
self.total_steps = self.epochs * self.steps_per_epoch
self.time_count = {} self.time_count = {}
self.best_metric = {} self.best_metric = {}
...@@ -85,22 +143,24 @@ class Trainer: ...@@ -85,22 +143,24 @@ class Trainer:
reader_cost_averager = TimeAverager() reader_cost_averager = TimeAverager()
batch_cost_averager = TimeAverager() batch_cost_averager = TimeAverager()
for epoch in range(self.start_epoch, self.epochs + 1): iter_loader = IterLoader(self.train_dataloader)
self.current_epoch = epoch
while self.current_iter < (self.total_iters + 1):
self.current_epoch = iter_loader.epoch
self.inner_iter = self.current_iter % self.iters_per_epoch
start_time = step_start_time = time.time() start_time = step_start_time = time.time()
for i, data in enumerate(self.train_dataloader): data = next(iter_loader)
reader_cost_averager.record(time.time() - step_start_time) reader_cost_averager.record(time.time() - step_start_time)
self.batch_id = i
# unpack data from dataset and apply preprocessing # unpack data from dataset and apply preprocessing
# data input should be dict # data input should be dict
self.model.set_input(data) self.model.setup_input(data)
self.model.optimize_parameters() self.model.train_iter(self.optimizers)
batch_cost_averager.record(time.time() - step_start_time, batch_cost_averager.record(time.time() - step_start_time,
num_samples=self.cfg.get( num_samples=self.cfg.get(
'batch_size', 1)) 'batch_size', 1))
if i % self.log_interval == 0: if self.current_iter % self.log_interval == 0:
self.data_time = reader_cost_averager.get_average() self.data_time = reader_cost_averager.get_average()
self.step_time = batch_cost_averager.get_average() self.step_time = batch_cost_averager.get_average()
self.ips = batch_cost_averager.get_ips_average() self.ips = batch_cost_averager.get_ips_average()
...@@ -109,93 +169,42 @@ class Trainer: ...@@ -109,93 +169,42 @@ class Trainer:
reader_cost_averager.reset() reader_cost_averager.reset()
batch_cost_averager.reset() batch_cost_averager.reset()
if i % self.visual_interval == 0: if self.current_iter % self.visual_interval == 0:
self.visual('visual_train') self.visual('visual_train')
self.global_steps += 1
step_start_time = time.time() step_start_time = time.time()
self.logger.info(
'train one epoch use time: {:.3f} seconds.'.format(time.time() -
start_time))
if self.validate_interval > -1 and epoch % self.validate_interval:
self.validate()
self.model.lr_scheduler.step() self.model.lr_scheduler.step()
if epoch % self.weight_interval == 0:
self.save(epoch, 'weight', keep=-1)
self.save(epoch)
def validate(self):
if not hasattr(self, 'val_dataloader'):
self.val_dataloader = build_dataloader(self.cfg.dataset.val,
is_train=False)
metric_result = {}
for i, data in enumerate(self.val_dataloader):
self.batch_id = i
self.model.set_input(data)
self.model.test()
visual_results = {}
current_paths = self.model.get_image_paths()
current_visuals = self.model.get_current_visuals()
for j in range(len(current_paths)): if self.by_epoch:
short_path = os.path.basename(current_paths[j]) temp = self.current_epoch
basename = os.path.splitext(short_path)[0]
for k, img_tensor in current_visuals.items():
name = '%s_%s' % (basename, k)
visual_results.update({name: img_tensor[j]})
if 'psnr' in self.cfg.validate.metrics:
if 'psnr' not in metric_result:
metric_result['psnr'] = calculate_psnr(
tensor2img(current_visuals['output'][j], (0., 1.)),
tensor2img(current_visuals['gt'][j], (0., 1.)),
**self.cfg.validate.metrics.psnr)
else:
metric_result['psnr'] += calculate_psnr(
tensor2img(current_visuals['output'][j], (0., 1.)),
tensor2img(current_visuals['gt'][j], (0., 1.)),
**self.cfg.validate.metrics.psnr)
if 'ssim' in self.cfg.validate.metrics:
if 'ssim' not in metric_result:
metric_result['ssim'] = calculate_ssim(
tensor2img(current_visuals['output'][j], (0., 1.)),
tensor2img(current_visuals['gt'][j], (0., 1.)),
**self.cfg.validate.metrics.ssim)
else: else:
metric_result['ssim'] += calculate_ssim( temp = self.current_iter
tensor2img(current_visuals['output'][j], (0., 1.)), if self.validate_interval > -1 and temp % self.validate_interval == 0:
tensor2img(current_visuals['gt'][j], (0., 1.)), self.test()
**self.cfg.validate.metrics.ssim)
self.visual('visual_val', if temp % self.weight_interval == 0:
visual_results=visual_results, self.save(temp, 'weight', keep=-1)
step=self.batch_id) self.save(temp)
if i % self.log_interval == 0: self.current_iter += 1
self.logger.info('val iter: [%d/%d]' %
(i, len(self.val_dataloader)))
for metric_name in metric_result.keys():
metric_result[metric_name] /= len(self.val_dataloader.dataset)
self.logger.info('Epoch {} validate end: {}'.format(
self.current_epoch, metric_result))
def test(self): def test(self):
if not hasattr(self, 'test_dataloader'): if not hasattr(self, 'test_dataloader'):
self.test_dataloader = build_dataloader(self.cfg.dataset.test, self.test_dataloader = build_dataloader(self.cfg.dataset.test,
is_train=False) is_train=False,
distributed=False)
if self.metrics:
for metric in self.metrics.values():
metric.reset()
# data[0]: img, data[1]: img path index # data[0]: img, data[1]: img path index
# test batch size must be 1 # test batch size must be 1
for i, data in enumerate(self.test_dataloader): for i, data in enumerate(self.test_dataloader):
self.batch_id = i
self.model.set_input(data) self.model.setup_input(data)
self.model.test() self.model.test_iter(metrics=self.metrics)
visual_results = {} visual_results = {}
current_paths = self.model.get_image_paths() current_paths = self.model.get_image_paths()
...@@ -217,11 +226,23 @@ class Trainer: ...@@ -217,11 +226,23 @@ class Trainer:
self.logger.info('Test iter: [%d/%d]' % self.logger.info('Test iter: [%d/%d]' %
(i, len(self.test_dataloader))) (i, len(self.test_dataloader)))
if self.metrics:
for metric_name, metric in self.metrics.items():
self.logger.info("Metric {}: {:.4f}".format(
metric_name, metric.accumulate()))
def print_log(self): def print_log(self):
losses = self.model.get_current_losses() losses = self.model.get_current_losses()
message = 'Epoch: %d, iters: %d ' % (self.current_epoch, self.batch_id)
message += '%s: %.6f ' % ('lr', self.current_learning_rate) message = ''
if self.by_epoch:
message += 'Epoch: %d/%d, iter: %d/%d ' % (
self.current_epoch, self.epochs, self.inner_iter,
self.iters_per_epoch)
else:
message += 'Iter: %d/%d ' % (self.current_iter, self.total_iters)
message += f'lr: {self.current_learning_rate:.3e} '
for k, v in losses.items(): for k, v in losses.items():
message += '%s: %.3f ' % (k, v) message += '%s: %.3f ' % (k, v)
...@@ -238,9 +259,7 @@ class Trainer: ...@@ -238,9 +259,7 @@ class Trainer:
message += 'ips: %.5f images/s ' % self.ips message += 'ips: %.5f images/s ' % self.ips
if hasattr(self, 'step_time'): if hasattr(self, 'step_time'):
cur_step = self.steps_per_epoch * (self.current_epoch - eta = self.step_time * (self.total_iters - self.current_iter - 1)
1) + self.batch_id
eta = self.step_time * (self.total_steps - cur_step - 1)
eta_str = str(datetime.timedelta(seconds=int(eta))) eta_str = str(datetime.timedelta(seconds=int(eta)))
message += f'eta: {eta_str}' message += f'eta: {eta_str}'
...@@ -274,6 +293,7 @@ class Trainer: ...@@ -274,6 +293,7 @@ class Trainer:
min_max = self.cfg.get('min_max', None) min_max = self.cfg.get('min_max', None)
if min_max is None: if min_max is None:
min_max = (-1., 1.) min_max = (-1., 1.)
image_num = self.cfg.get('image_num', None) image_num = self.cfg.get('image_num', None)
if (image_num is None) or (not self.enable_visualdl): if (image_num is None) or (not self.enable_visualdl):
image_num = 1 image_num = 1
...@@ -345,6 +365,14 @@ class Trainer: ...@@ -345,6 +365,14 @@ class Trainer:
state_dicts = load(weight_path) state_dicts = load(weight_path)
for net_name, net in self.model.nets.items(): for net_name, net in self.model.nets.items():
if net_name in state_dicts:
net.set_state_dict(state_dicts[net_name])
self.logger.info(
'Loaded pretrained weight for net {}'.format(net_name))
else:
self.logger.warning(
'Can not find state dict of net {}. Skip load pretrained weight for net {}'
.format(net_name, net_name))
net.set_state_dict(state_dicts[net_name]) net.set_state_dict(state_dicts[net_name])
def close(self): def close(self):
......
[English](./README.md)
# Usage
To compute the FID score between two datasets, where the images of each dataset are contained in an individual folder, first download the InceptionV3 weights, then run the script:
```
wget https://paddlegan.bj.bcebos.com/InceptionV3.pdparams
python test_fid_score.py --image_data_path1 /path/to/dataset1 --image_data_path2 /path/to/dataset2 --inference_model ./InceptionV3.pdparams
```
### Inception-V3 weights converted from torchvision
Download: https://aistudio.baidu.com/aistudio/datasetdetail/51890
These model weights were converted from the official torchvision Inception-V3 model; both BigGAN and StarGAN-v2 use them to calculate the FID score.
Note that these weights differ from the ones above, which were converted from an unofficial TensorFlow implementation.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
def reorder_image(img, input_order='HWC'):
"""Reorder images to 'HWC' order.
If the input_order is (h, w), return (h, w, 1);
If the input_order is (c, h, w), return (h, w, c);
If the input_order is (h, w, c), return as it is.
Args:
img (ndarray): Input image.
input_order (str): Whether the input order is 'HWC' or 'CHW'.
If the input image shape is (h, w), input_order will not have
effects. Default: 'HWC'.
Returns:
ndarray: reordered image.
"""
if input_order not in ['HWC', 'CHW']:
raise ValueError(
f'Wrong input_order {input_order}. Supported input_orders are '
"'HWC' and 'CHW'")
if len(img.shape) == 2:
img = img[..., None]
return img
if input_order == 'CHW':
img = img.transpose(1, 2, 0)
return img
def bgr2ycbcr(img, y_only=False):
"""Convert a BGR image to YCbCr image.
The bgr version of rgb2ycbcr.
It implements the ITU-R BT.601 conversion for standard-definition
television. See more details in
https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`.
In OpenCV, it implements a JPEG conversion. See more details in
https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
Args:
img (ndarray): The input image. It accepts:
1. np.uint8 type with range [0, 255];
2. np.float32 type with range [0, 1].
y_only (bool): Whether to only return Y channel. Default: False.
Returns:
ndarray: The converted YCbCr image. The output image has the same type
and range as input image.
"""
img_type = img.dtype
#img = _convert_input_type_range(img)
if y_only:
out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0
else:
out_img = np.matmul(
img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
[65.481, -37.797, 112.0]]) + [16, 128, 128]
#out_img = _convert_output_type_range(out_img, img_type)
return out_img
def to_y_channel(img):
"""Change to Y channel of YCbCr.
Args:
img (ndarray): Images with range [0, 255].
Returns:
(ndarray): Images with range [0, 255] (float type) without round.
"""
img = img.astype(np.float32) / 255.
if img.ndim == 3 and img.shape[2] == 3:
img = bgr2ycbcr(img, y_only=True)
img = img[..., None]
return img * 255.
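# --- Worked check (illustrative, not part of the original file) ---
# Sanity-check the BT.601 luma weights above: a pure white pixel in [0, 1]
# should map to studio-swing white, Y = 235.
if __name__ == '__main__':
    white = np.array([1.0, 1.0, 1.0], dtype=np.float32)  # BGR in [0, 1]
    y = np.dot(white, [24.966, 128.553, 65.481]) + 16.0
    print(round(float(y), 1))  # -> 235.0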
#Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import argparse
from compute_fid import *
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--image_data_path1',
type=str,
default='./real',
help='path of image data')
parser.add_argument('--image_data_path2',
type=str,
default='./fake',
help='path of image data')
parser.add_argument('--inference_model',
type=str,
default='./pretrained/params_inceptionV3',
help='path of inference_model.')
# note: argparse's type=bool is truthy for any non-empty string, so
# '--use_gpu False' still enables the GPU; pass an empty string to disable.
parser.add_argument('--use_gpu',
type=bool,
default=True,
help='whether to use gpu (default: True).')
parser.add_argument('--batch_size',
type=int,
default=1,
help='sample number in a batch for inference.')
parser.add_argument(
'--style',
type=str,
help='calculation style: stargan or default (gan-compression style)')
args = parser.parse_args()
return args
def main():
args = parse_args()
path1 = args.image_data_path1
path2 = args.image_data_path2
paths = (path1, path2)
inference_model_path = args.inference_model
batch_size = args.batch_size
fid_value = calculate_fid_given_paths(paths,
inference_model_path,
batch_size,
args.use_gpu,
2048,
style=args.style)
print('FID: ', fid_value)
if __name__ == "__main__":
main()
...@@ -11,3 +11,6 @@ ...@@ -11,3 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from .psnr_ssim import PSNR, SSIM
from .builder import build_metric
...@@ -12,4 +12,16 @@ ...@@ -12,4 +12,16 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from .transforms import PairedRandomCrop, PairedRandomHorizontalFlip, Add, ResizeToScale import copy
import paddle
from ..utils.registry import Registry
METRICS = Registry("METRIC")
def build_metric(cfg):
cfg_ = cfg.copy()
name = cfg_.pop('name', None)
metric = METRICS.get(name)(**cfg_)
return metric
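# --- Usage sketch (illustrative, not part of the original file) ---
# A config dict whose 'name' selects the registered class; the remaining
# keys become constructor kwargs. crop_border=4 is an arbitrary example.
_psnr_cfg = {'name': 'PSNR', 'crop_border': 4, 'test_y_channel': True}
# metric = build_metric(_psnr_cfg)  # -> PSNR(crop_border=4, test_y_channel=True)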
...@@ -14,8 +14,59 @@ ...@@ -14,8 +14,59 @@
import cv2 import cv2
import numpy as np import numpy as np
import paddle
from .metric_util import reorder_image, to_y_channel from .builder import METRICS
@METRICS.register()
class PSNR(paddle.metric.Metric):
def __init__(self, crop_border, input_order='HWC', test_y_channel=False):
self.crop_border = crop_border
self.input_order = input_order
self.test_y_channel = test_y_channel
self.reset()
def reset(self):
self.results = []
def update(self, preds, gts):
if not isinstance(preds, (list, tuple)):
preds = [preds]
if not isinstance(gts, (list, tuple)):
gts = [gts]
for pred, gt in zip(preds, gts):
value = calculate_psnr(pred, gt, self.crop_border, self.input_order,
self.test_y_channel)
self.results.append(value)
def accumulate(self):
if len(self.results) <= 0:
return 0.
return np.mean(self.results)
def name(self):
return 'PSNR'
@METRICS.register()
class SSIM(PSNR):
def update(self, preds, gts):
if not isinstance(preds, (list, tuple)):
preds = [preds]
if not isinstance(gts, (list, tuple)):
gts = [gts]
for pred, gt in zip(preds, gts):
value = calculate_ssim(pred, gt, self.crop_border, self.input_order,
self.test_y_channel)
self.results.append(value)
def name(self):
return 'SSIM'
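# --- Usage sketch (illustrative, not part of the original file) ---
# The metrics follow paddle's streaming API: update() per batch,
# accumulate() for the running mean. Dummy uint8 images; values illustrative.
if __name__ == '__main__':
    psnr = PSNR(crop_border=0)
    img1 = np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8)
    img2 = np.clip(img1.astype(np.int32) + 1, 0, 255).astype(np.uint8)
    psnr.update(img1, img2)
    print(psnr.name(), psnr.accumulate())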
def calculate_psnr(img1, def calculate_psnr(img1,
...@@ -46,6 +97,8 @@ def calculate_psnr(img1, ...@@ -46,6 +97,8 @@ def calculate_psnr(img1,
raise ValueError( raise ValueError(
f'Wrong input_order {input_order}. Supported input_orders are ' f'Wrong input_order {input_order}. Supported input_orders are '
'"HWC" and "CHW"') '"HWC" and "CHW"')
img1 = img1.copy().astype('float32')
img2 = img2.copy().astype('float32')
img1 = reorder_image(img1, input_order=input_order) img1 = reorder_image(img1, input_order=input_order)
img2 = reorder_image(img2, input_order=input_order) img2 = reorder_image(img2, input_order=input_order)
...@@ -134,6 +187,10 @@ def calculate_ssim(img1, ...@@ -134,6 +187,10 @@ def calculate_ssim(img1,
raise ValueError( raise ValueError(
f'Wrong input_order {input_order}. Supported input_orders are ' f'Wrong input_order {input_order}. Supported input_orders are '
'"HWC" and "CHW"') '"HWC" and "CHW"')
img1 = img1.copy().astype('float32')[..., ::-1]
img2 = img2.copy().astype('float32')[..., ::-1]
img1 = reorder_image(img1, input_order=input_order) img1 = reorder_image(img1, input_order=input_order)
img2 = reorder_image(img2, input_order=input_order) img2 = reorder_image(img2, input_order=input_order)
...@@ -149,3 +206,81 @@ def calculate_ssim(img1, ...@@ -149,3 +206,81 @@ def calculate_ssim(img1,
for i in range(img1.shape[2]): for i in range(img1.shape[2]):
ssims.append(_ssim(img1[..., i], img2[..., i])) ssims.append(_ssim(img1[..., i], img2[..., i]))
return np.array(ssims).mean() return np.array(ssims).mean()
def reorder_image(img, input_order='HWC'):
"""Reorder images to 'HWC' order.
If the input_order is (h, w), return (h, w, 1);
If the input_order is (c, h, w), return (h, w, c);
If the input_order is (h, w, c), return as it is.
Args:
img (ndarray): Input image.
input_order (str): Whether the input order is 'HWC' or 'CHW'.
If the input image shape is (h, w), input_order will not have
effects. Default: 'HWC'.
Returns:
ndarray: reordered image.
"""
if input_order not in ['HWC', 'CHW']:
raise ValueError(
f'Wrong input_order {input_order}. Supported input_orders are '
"'HWC' and 'CHW'")
if len(img.shape) == 2:
img = img[..., None]
return img
if input_order == 'CHW':
img = img.transpose(1, 2, 0)
return img
def bgr2ycbcr(img, y_only=False):
"""Convert a BGR image to YCbCr image.
The bgr version of rgb2ycbcr.
It implements the ITU-R BT.601 conversion for standard-definition
television. See more details in
https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`.
In OpenCV, it implements a JPEG conversion. See more details in
https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
Args:
img (ndarray): The input image. It accepts:
1. np.uint8 type with range [0, 255];
2. np.float32 type with range [0, 1].
y_only (bool): Whether to only return Y channel. Default: False.
Returns:
ndarray: The converted YCbCr image. The output image has the same type
and range as input image.
"""
img_type = img.dtype
if y_only:
out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0
else:
out_img = np.matmul(
img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
[65.481, -37.797, 112.0]]) + [16, 128, 128]
return out_img
def to_y_channel(img):
"""Change to Y channel of YCbCr.
Args:
img (ndarray): Images with range [0, 255].
Returns:
(ndarray): Images with range [0, 255] (float type) without round.
"""
img = img.astype(np.float32) / 255.
if img.ndim == 3 and img.shape[2] == 3:
img = bgr2ycbcr(img, y_only=True)
img = img[..., None]
return img * 255.
...@@ -16,9 +16,9 @@ from .base_model import BaseModel ...@@ -16,9 +16,9 @@ from .base_model import BaseModel
from .gan_model import GANModel from .gan_model import GANModel
from .cycle_gan_model import CycleGANModel from .cycle_gan_model import CycleGANModel
from .pix2pix_model import Pix2PixModel from .pix2pix_model import Pix2PixModel
from .srgan_model import SRGANModel from .sr_model import BaseSRModel
from .sr_model import SRModel
from .makeup_model import MakeupModel from .makeup_model import MakeupModel
from .esrgan_model import ESRGAN
from .ugatit_model import UGATITModel from .ugatit_model import UGATITModel
from .dc_gan_model import DCGANModel from .dc_gan_model import DCGANModel
from .animeganv2_model import AnimeGANV2Model, AnimeGANV2PreTrainModel from .animeganv2_model import AnimeGANV2Model, AnimeGANV2PreTrainModel
...@@ -19,7 +19,7 @@ from .base_model import BaseModel ...@@ -19,7 +19,7 @@ from .base_model import BaseModel
from .builder import MODELS from .builder import MODELS
from .generators.builder import build_generator from .generators.builder import build_generator
from .discriminators.builder import build_discriminator from .discriminators.builder import build_discriminator
from .losses import GANLoss from .criterions.gan_loss import GANLoss
from ..modules.caffevgg import CaffeVGG19 from ..modules.caffevgg import CaffeVGG19
from ..solver import build_optimizer from ..solver import build_optimizer
from ..modules.init import init_weights from ..modules.init import init_weights
......
...@@ -19,24 +19,40 @@ import numpy as np ...@@ -19,24 +19,40 @@ import numpy as np
from collections import OrderedDict from collections import OrderedDict
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from ..solver.lr_scheduler import build_lr_scheduler from .criterions.builder import build_criterion
from ..solver import build_lr_scheduler, build_optimizer
from ..metrics import build_metric
from ..utils.visual import tensor2img
class BaseModel(ABC): class BaseModel(ABC):
"""This class is an abstract base class (ABC) for models. """This class is an abstract base class (ABC) for models.
To create a subclass, you need to implement the following five functions: To create a subclass, you need to implement the following four functions:
-- <__init__>: initialize the class; first call BaseModel.__init__(self, opt). -- <__init__>: initialize the class.
-- <set_input>: unpack data from dataset and apply preprocessing. -- <setup_input>: unpack data from dataset and apply preprocessing.
-- <forward>: produce intermediate results. -- <forward>: produce intermediate results.
-- <optimize_parameters>: calculate losses, gradients, and update network weights. -- <train_iter>: calculate losses, gradients, and update network weights.
-- <modify_commandline_options>: (optionally) add model-specific options and set default options.
# trainer training logic:
#
# build_model || model(BaseModel)
# | ||
# build_dataloader || dataloader
# | ||
# model.setup_lr_schedulers || lr_scheduler
# | ||
# model.setup_optimizers || optimizers
# | ||
# train loop (model.setup_input + model.train_iter) || train loop
# | ||
# print log (model.get_current_losses) ||
# | ||
# save checkpoint (model.nets) \/
""" """
def __init__(self, cfg): def __init__(self):
"""Initialize the BaseModel class. """Initialize the BaseModel class.
Args:
cfg (Dict)-- configs of Model.
When creating your custom class, you need to implement your own initialization. When creating your custom class, you need to implement your own initialization.
In this function, you should first call <super(YourClass, self).__init__(self, cfg)> In this function, you should first call <super(YourClass, self).__init__(self, cfg)>
Then, you need to define four lists: Then, you need to define four lists:
...@@ -47,57 +63,85 @@ class BaseModel(ABC): ...@@ -47,57 +63,85 @@ class BaseModel(ABC):
If two networks are updated at the same time, you can use itertools.chain to group them. If two networks are updated at the same time, you can use itertools.chain to group them.
See cycle_gan_model.py for an example. See cycle_gan_model.py for an example.
""" """
self.cfg = cfg
self.is_train = cfg.is_train
self.save_dir = os.path.join(
cfg.output_dir,
cfg.model.name) # save all the checkpoints to save_dir
self.losses = OrderedDict()
self.nets = OrderedDict() self.nets = OrderedDict()
self.visual_items = OrderedDict()
self.optimizers = OrderedDict() self.optimizers = OrderedDict()
self.image_paths = [] self.metrics = OrderedDict()
self.metric = 0 # used for learning rate policy 'plateau' self.losses = OrderedDict()
self.visual_items = OrderedDict()
@abstractmethod @abstractmethod
def set_input(self, input): def setup_input(self, input):
"""Unpack input data from the dataloader and perform necessary pre-processing steps. """Unpack input data from the dataloader and perform necessary pre-processing steps.
Parameters: Args:
input (dict): includes the data itself and its metadata information. input (dict): includes the data itself and its metadata information.
""" """
pass pass
@abstractmethod @abstractmethod
def forward(self): def forward(self):
"""Run forward pass; called by both functions <optimize_parameters> and <test>.""" """Run forward pass; called by both functions <train_iter> and <test_iter>."""
pass pass
@abstractmethod @abstractmethod
def optimize_parameters(self): def train_iter(self, optims=None):
"""Calculate losses, gradients, and update network weights; called in every training iteration""" """Calculate losses, gradients, and update network weights; called in every training iteration"""
pass pass
def build_lr_scheduler(self): def test_iter(self, metrics=None):
self.lr_scheduler = build_lr_scheduler(self.cfg.lr_scheduler) """Calculate metrics; called in every test iteration"""
self.eval()
with paddle.no_grad():
self.forward()
self.train()
def setup_train_mode(self, is_train):
self.is_train = is_train
def setup_lr_schedulers(self, cfg):
self.lr_scheduler = build_lr_scheduler(cfg)
return self.lr_scheduler
def setup_optimizers(self, lr, cfg):
if cfg.get('name', None):
cfg_ = cfg.copy()
net_names = cfg_.pop('net_names')
parameters = []
for net_name in net_names:
parameters += self.nets[net_name].parameters()
self.optimizers['optim'] = build_optimizer(cfg_, lr, parameters)
else:
for opt_name, opt_cfg in cfg.items():
cfg_ = opt_cfg.copy()
net_names = cfg_.pop('net_names')
parameters = []
for net_name in net_names:
parameters += self.nets[net_name].parameters()
self.optimizers[opt_name] = build_optimizer(
cfg_, lr, parameters)
return self.optimizers
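# --- Illustrative note (not part of the original file) ---
# setup_optimizers accepts either a flat config containing 'name', which
# builds a single optimizer over all nets listed in 'net_names', or a dict
# of sub-configs, which builds one optimizer per entry, e.g. (values assumed):
#   {'name': 'Adam', 'beta1': 0.5, 'net_names': ['netG_A', 'netG_B']}
#   {'optimG': {'name': 'Adam', 'net_names': ['netG_A', 'netG_B']},
#    'optimD': {'name': 'Adam', 'net_names': ['netD_A', 'netD_B']}}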
def setup_metrics(self, cfg):
if isinstance(list(cfg.values())[0], dict):
for metric_name, cfg_ in cfg.items():
self.metrics[metric_name] = build_metric(cfg_)
else:
metric = build_metric(cfg)
self.metrics[metric.__class__.__name__] = metric
return self.metrics
def eval(self): def eval(self):
"""Make models eval mode during test time""" """Make nets eval mode during test time"""
for name in self.model_names: for net in self.nets.values():
if isinstance(name, str):
net = getattr(self, 'net' + name)
net.eval() net.eval()
def test(self): def train(self):
"""Forward function used in test time. """Make nets train mode during train time"""
for net in self.nets.values():
This function wraps <forward> function in no_grad() so we don't save intermediate steps for backprop net.train()
It also calls <compute_visuals> to produce additional visualization results
"""
with paddle.no_grad():
self.forward()
self.compute_visuals()
def compute_visuals(self): def compute_visuals(self):
"""Calculate additional output images for visdom and HTML visualization""" """Calculate additional output images for visdom and HTML visualization"""
...@@ -118,8 +162,8 @@ class BaseModel(ABC): ...@@ -118,8 +162,8 @@ class BaseModel(ABC):
def set_requires_grad(self, nets, requires_grad=False): def set_requires_grad(self, nets, requires_grad=False):
"""Set requies_grad=Fasle for all the networks to avoid unnecessary computations """Set requies_grad=Fasle for all the networks to avoid unnecessary computations
Args: Args:
nets (network list) -- a list of networks nets (network list): a list of networks
requires_grad (bool) -- whether the networks require gradients or not requires_grad (bool): whether the networks require gradients or not
""" """
if not isinstance(nets, list): if not isinstance(nets, list):
nets = [nets] nets = [nets]
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import copy
import paddle import paddle
from ..utils.registry import Registry from ..utils.registry import Registry
...@@ -20,5 +21,7 @@ MODELS = Registry("MODEL") ...@@ -20,5 +21,7 @@ MODELS = Registry("MODEL")
def build_model(cfg): def build_model(cfg):
model = MODELS.get(cfg.model.name)(cfg) cfg_ = cfg.copy()
name = cfg_.pop('name', None)
model = MODELS.get(name)(**cfg_)
return model return model
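# --- Hypothetical model config (illustrative, not part of the original file) ---
# build_model pops 'name' and forwards the rest as constructor kwargs; the
# sub-config names below are assumptions that mirror CycleGANModel's signature.
_cyclegan_cfg = {
    'name': 'CycleGANModel',
    'generator': {'name': 'ResnetGenerator'},          # abbreviated sub-config
    'discriminator': {'name': 'NLayerDiscriminator'},  # abbreviated sub-config
    'gan_criterion': {'name': 'GANLoss', 'gan_mode': 'lsgan'},
    'cycle_criterion': {'name': 'L1Loss'},
    'idt_criterion': {'name': 'L1Loss', 'loss_weight': 0.5},
    'direction': 'a2b',
}
# model = build_model(_cyclegan_cfg)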
from .gan_loss import GANLoss
from .perceptual_loss import PerceptualLoss
from .pixel_loss import L1Loss, MSELoss
from .builder import build_criterion
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...utils.registry import Registry
CRITERIONS = Registry('CRITERION')
def build_criterion(cfg):
cfg_ = cfg.copy()
name = cfg_.pop('name')
try:
criterion = CRITERIONS.get(name)(**cfg_)
except Exception as e:
cls_ = CRITERIONS.get(name)
raise RuntimeError('class {} {}'.format(cls_.__name__, e))
return criterion
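# --- Usage sketch (illustrative, not part of the original file) ---
# Same registry pattern as models/metrics; a failing constructor is re-raised
# with the class name attached. loss_weight=0.5 is an arbitrary example.
_idt_cfg = {'name': 'L1Loss', 'loss_weight': 0.5}
# idt_criterion = build_criterion(_idt_cfg)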
...@@ -16,29 +16,40 @@ import numpy as np ...@@ -16,29 +16,40 @@ import numpy as np
import paddle import paddle
import paddle.nn as nn import paddle.nn as nn
from .builder import CRITERIONS
import paddle.nn.functional as F import paddle.nn.functional as F
@CRITERIONS.register()
class GANLoss(nn.Layer): class GANLoss(nn.Layer):
"""Define different GAN objectives. """Define different GAN objectives.
The GANLoss class abstracts away the need to create the target label tensor The GANLoss class abstracts away the need to create the target label tensor
that has the same size as the input. that has the same size as the input.
""" """
def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0): def __init__(self,
gan_mode,
target_real_label=1.0,
target_fake_label=0.0,
loss_weight=1.0):
""" Initialize the GANLoss class. """ Initialize the GANLoss class.
Parameters: Args:
gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp. gan_mode (str): the type of GAN objective. It currently supports vanilla, lsgan, and wgangp.
target_real_label (bool) - - label for a real image target_real_label (bool): label for a real image
target_fake_label (bool) - - label of a fake image target_fake_label (bool): label of a fake image
Note: Do not use sigmoid as the last layer of Discriminator. Note: Do not use sigmoid as the last layer of Discriminator.
LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss. LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss.
""" """
super(GANLoss, self).__init__() super(GANLoss, self).__init__()
# a non-positive loss_weight disables this loss; skip initialization
if loss_weight <= 0:
return None
self.target_real_label = target_real_label self.target_real_label = target_real_label
self.target_fake_label = target_fake_label self.target_fake_label = target_fake_label
self.loss_weight = loss_weight
self.gan_mode = gan_mode self.gan_mode = gan_mode
if gan_mode == 'lsgan': if gan_mode == 'lsgan':
...@@ -53,7 +64,7 @@ class GANLoss(nn.Layer): ...@@ -53,7 +64,7 @@ class GANLoss(nn.Layer):
def get_target_tensor(self, prediction, target_is_real): def get_target_tensor(self, prediction, target_is_real):
"""Create label tensors with the same size as the input. """Create label tensors with the same size as the input.
Parameters: Args:
prediction (tensor) - - typically the prediction from a discriminator prediction (tensor) - - typically the prediction from a discriminator
target_is_real (bool) - - if the ground truth label is for real images or fake images target_is_real (bool) - - if the ground truth label is for real images or fake images
...@@ -75,13 +86,16 @@ class GANLoss(nn.Layer): ...@@ -75,13 +86,16 @@ class GANLoss(nn.Layer):
dtype='float32') dtype='float32')
target_tensor = self.target_fake_tensor target_tensor = self.target_fake_tensor
# target_tensor.stop_gradient = True
return target_tensor return target_tensor
def __call__(self, prediction, target_is_real, is_updating_D=None): def __call__(self,
prediction,
target_is_real,
is_disc=False,
is_updating_D=None):
"""Calculate loss given Discriminator's output and grount truth labels. """Calculate loss given Discriminator's output and grount truth labels.
Parameters: Args:
prediction (tensor) - - typically the prediction output from a discriminator prediction (tensor) - - typically the prediction output from a discriminator
target_is_real (bool) - - if the ground truth label is for real images or fake images target_is_real (bool) - - if the ground truth label is for real images or fake images
is_updating_D (bool) - - if we are in updating D step or not is_updating_D (bool) - - if we are in updating D step or not
...@@ -108,4 +122,5 @@ class GANLoss(nn.Layer): ...@@ -108,4 +122,5 @@ class GANLoss(nn.Layer):
loss = F.softplus(-prediction).mean() loss = F.softplus(-prediction).mean()
else: else:
loss = F.softplus(prediction).mean() loss = F.softplus(prediction).mean()
return loss
return loss if is_disc else loss * self.loss_weight
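# --- Usage sketch (illustrative, not part of the original file) ---
# lsgan-mode GANLoss on a PatchGAN-style prediction map. With is_disc=True
# the raw loss is returned (discriminator step); otherwise it is scaled by
# loss_weight (0.3 is an arbitrary illustrative value).
if __name__ == '__main__':
    gan_loss = GANLoss(gan_mode='lsgan', loss_weight=0.3)
    pred = paddle.rand([4, 1, 30, 30])
    d_loss = gan_loss(pred, target_is_real=True, is_disc=True)
    g_loss = gan_loss(pred, target_is_real=False)  # scaled by loss_weight
    print(float(d_loss), float(g_loss))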
import paddle
import paddle.nn as nn
import paddle.vision.models.vgg as vgg
from ppgan.utils.download import get_path_from_url
from .builder import CRITERIONS
class PerceptualVGG(nn.Layer):
"""VGG network used in calculating perceptual loss.
In this implementation, we allow users to choose whether use normalization
in the input feature and the type of vgg network. Note that the pretrained
path must fit the vgg type.
Args:
layer_name_list (list[str]): According to the name in this list,
forward function will return the corresponding features. This
list contains the name each layer in `vgg.feature`. An example
of this list is ['4', '10'].
vgg_type (str): Set the type of vgg network. Default: 'vgg19'.
use_input_norm (bool): If True, normalize the input image.
Importantly, the input feature must be in the range [0, 1].
Default: True.
pretrained_url (str): Path or URL for pretrained weights. Default: the
official PaddleGAN vgg19 weights.
"""
def __init__(
self,
layer_name_list,
vgg_type='vgg19',
use_input_norm=True,
pretrained_url='https://paddlegan.bj.bcebos.com/model/vgg19.pdparams'
):
super(PerceptualVGG, self).__init__()
self.layer_name_list = layer_name_list
self.use_input_norm = use_input_norm
# get vgg model and load pretrained vgg weight
_vgg = getattr(vgg, vgg_type)()
if pretrained_url:
weight_path = get_path_from_url(pretrained_url)
state_dict = paddle.load(weight_path)
_vgg.load_dict(state_dict)
print('PerceptualVGG loaded pretrained weight.')
num_layers = max(map(int, layer_name_list)) + 1
assert len(_vgg.features) >= num_layers
# only borrow layers that will be used from _vgg to avoid unused params
self.vgg_layers = nn.Sequential(
*list(_vgg.features.children())[:num_layers])
if self.use_input_norm:
# the mean is for image with range [0, 1]
self.register_buffer(
'mean',
paddle.to_tensor([0.485, 0.456, 0.406]).reshape([1, 3, 1, 1]))
# the std is for image with range [0, 1]
self.register_buffer(
'std',
paddle.to_tensor([0.229, 0.224, 0.225]).reshape([1, 3, 1, 1]))
for v in self.vgg_layers.parameters():
v.trainable = False
def forward(self, x):
"""Forward function.
Args:
x (Tensor): Input tensor with shape (n, c, h, w).
Returns:
Tensor: Forward results.
"""
if self.use_input_norm:
x = (x - self.mean) / self.std
output = {}
for name, module in self.vgg_layers.named_children():
x = module(x)
if name in self.layer_name_list:
output[name] = x.clone()
return output
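# --- Usage sketch (illustrative, not part of the original file) ---
# pretrained_url=None skips the weight download (random weights), which is
# enough to demonstrate the API; '4' and '10' are example layer indices.
if __name__ == '__main__':
    extractor = PerceptualVGG(layer_name_list=['4', '10'], pretrained_url=None)
    feats = extractor(paddle.rand([1, 3, 64, 64]))
    print(sorted(feats.keys()))  # -> ['10', '4']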
@CRITERIONS.register()
class PerceptualLoss(nn.Layer):
"""Perceptual loss with commonly used style loss.
Args:
layer_weights (dict): The weight for each layer of vgg feature.
Here is an example: {'4': 1., '9': 1., '18': 1.}, which means the
5th, 10th and 19th feature layers will be extracted with weight 1.0
in calculating losses.
vgg_type (str): The type of vgg network used as feature extractor.
Default: 'vgg19'.
use_input_norm (bool): If True, normalize the input image in vgg.
Default: True.
perceptual_weight (float): If `perceptual_weight > 0`, the perceptual
loss will be calculated and multiplied by the
weight. Default: 1.0.
style_weight (float): If `style_weight > 0`, the style loss will be
calculated and multiplied by the weight.
Default: 1.0.
norm_img (bool): If True, the image will be normalized to [0, 1]. Note
that this is different from `use_input_norm`, which normalizes the
input inside the forward function of vgg according to the statistics
of the dataset.
Importantly, the input image must be in the range [-1, 1].
pretrained (str): Path or URL for pretrained weights. Default: the
official PaddleGAN vgg19 weights.
"""
def __init__(
self,
layer_weights,
vgg_type='vgg19',
use_input_norm=True,
perceptual_weight=1.0,
style_weight=1.0,
norm_img=True,
pretrained='https://paddlegan.bj.bcebos.com/model/vgg19.pdparams',
criterion='l1'):
super(PerceptualLoss, self).__init__()
# when loss weight less than zero return None
if perceptual_weight <= 0 and style_weight <= 0:
return None
self.norm_img = norm_img
self.perceptual_weight = perceptual_weight
self.style_weight = style_weight
self.layer_weights = layer_weights
self.vgg = PerceptualVGG(layer_name_list=list(layer_weights.keys()),
vgg_type=vgg_type,
use_input_norm=use_input_norm,
pretrained_url=pretrained)
if criterion == 'l1':
self.criterion = nn.L1Loss()
else:
raise NotImplementedError(
f'{criterion} criterion has not been supported in'
' this version.')
def forward(self, x, gt):
"""Forward function.
Args:
x (Tensor): Input tensor with shape (n, c, h, w).
gt (Tensor): Ground-truth tensor with shape (n, c, h, w).
Returns:
Tensor: Forward results.
"""
if self.norm_img:
x = (x + 1.) * 0.5
gt = (gt + 1.) * 0.5
# extract vgg features
x_features = self.vgg(x)
gt_features = self.vgg(gt.detach())
# calculate perceptual loss
if self.perceptual_weight > 0:
percep_loss = 0
for k in x_features.keys():
percep_loss += self.criterion(
x_features[k], gt_features[k]) * self.layer_weights[k]
percep_loss *= self.perceptual_weight
else:
percep_loss = None
# calculate style loss
if self.style_weight > 0:
style_loss = 0
for k in x_features.keys():
style_loss += self.criterion(self._gram_mat(
x_features[k]), self._gram_mat(
gt_features[k])) * self.layer_weights[k]
style_loss *= self.style_weight
else:
style_loss = None
return percep_loss, style_loss
def _gram_mat(self, x):
"""Calculate Gram matrix.
Args:
x (paddle.Tensor): Tensor with shape of (n, c, h, w).
Returns:
paddle.Tensor: Gram matrix.
"""
(n, c, h, w) = x.shape
features = x.reshape([n, c, w * h])
features_t = features.transpose([0, 2, 1])
gram = features.bmm(features_t) / (c * h * w)
return gram
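# --- Worked check (illustrative, not part of the original file) ---
# For activations of shape (n, c, h, w), the Gram matrix is (n, c, c),
# normalized by c*h*w; this mirrors _gram_mat with made-up sizes.
if __name__ == '__main__':
    feat = paddle.rand([2, 8, 16, 16])
    f = feat.reshape([2, 8, 16 * 16])
    gram = f.bmm(f.transpose([0, 2, 1])) / (8 * 16 * 16)
    print(gram.shape)  # -> [2, 8, 8]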
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
import paddle.nn as nn
from .builder import CRITERIONS
@CRITERIONS.register()
class L1Loss():
"""L1 (mean absolute error, MAE) loss.
Args:
reduction (str): Specifies the reduction to apply to the output.
Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
loss_weight (float): Loss weight for L1 loss. Default: 1.0.
"""
def __init__(self, reduction='mean', loss_weight=1.0):
# a non-positive loss_weight disables this loss; skip initialization
if loss_weight <= 0:
return None
self._l1_loss = nn.L1Loss(reduction)
self.loss_weight = loss_weight
self.reduction = reduction
def __call__(self, pred, target, **kwargs):
"""Forward Function.
Args:
pred (Tensor): of shape (N, C, H, W). Predicted tensor.
target (Tensor): of shape (N, C, H, W). Ground truth tensor.
weight (Tensor, optional): of shape (N, C, H, W). Element-wise
weights. Default: None.
"""
return self.loss_weight * self._l1_loss(pred, target)
@CRITERIONS.register()
class MSELoss():
"""MSE (L2) loss.
Args:
reduction (str): Specifies the reduction to apply to the output.
Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
loss_weight (float): Loss weight for MSE loss. Default: 1.0.
"""
def __init__(self, reduction='mean', loss_weight=1.0):
# a non-positive loss_weight disables this loss; skip initialization
if loss_weight <= 0:
return None
self._l2_loss = nn.MSELoss(reduction)
self.loss_weight = loss_weight
self.reduction = reduction
def __call__(self, pred, target, **kwargs):
"""Forward Function.
Args:
pred (Tensor): of shape (N, C, H, W). Predicted tensor.
target (Tensor): of shape (N, C, H, W). Ground truth tensor.
weight (Tensor, optional): of shape (N, C, H, W). Element-wise
weights. Default: None.
"""
return self.loss_weight * self._l2_loss(pred, target)
@CRITERIONS.register()
class BCEWithLogitsLoss():
"""BCE loss.
Args:
reduction (str): Specifies the reduction to apply to the output.
Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
loss_weight (float): Loss weight for MSE loss. Default: 1.0.
"""
def __init__(self, reduction='mean', loss_weight=1.0):
# a non-positive loss_weight disables this loss; skip initialization
if loss_weight <= 0:
return None
self._bce_loss = nn.BCEWithLogitsLoss(reduction=reduction)
self.loss_weight = loss_weight
self.reduction = reduction
def __call__(self, pred, target, **kwargs):
"""Forward Function.
Args:
pred (Tensor): of shape (N, C, H, W). Predicted tensor.
target (Tensor): of shape (N, C, H, W). Ground truth tensor.
weight (Tensor, optional): of shape (N, C, H, W). Element-wise
weights. Default: None.
"""
return self.loss_weight * self._bce_loss(pred, target)
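# --- Usage sketch (illustrative, not part of the original file) ---
# The pixel losses share one calling convention; loss_weight=0.5 and the
# random tensors are placeholders.
if __name__ == '__main__':
    l1 = L1Loss(loss_weight=0.5)
    pred = paddle.rand([1, 3, 8, 8])
    target = paddle.rand([1, 3, 8, 8])
    print(float(l1(pred, target)))  # 0.5 * mean(|pred - target|)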
...@@ -18,9 +18,8 @@ from .base_model import BaseModel ...@@ -18,9 +18,8 @@ from .base_model import BaseModel
from .builder import MODELS from .builder import MODELS
from .generators.builder import build_generator from .generators.builder import build_generator
from .discriminators.builder import build_discriminator from .discriminators.builder import build_discriminator
from .losses import GANLoss from .criterions import build_criterion
from ..solver import build_optimizer
from ..modules.init import init_weights from ..modules.init import init_weights
from ..utils.image_pool import ImagePool from ..utils.image_pool import ImagePool
...@@ -30,61 +29,62 @@ class CycleGANModel(BaseModel): ...@@ -30,61 +29,62 @@ class CycleGANModel(BaseModel):
""" """
This class implements the CycleGAN model, for learning image-to-image translation without paired data. This class implements the CycleGAN model, for learning image-to-image translation without paired data.
The model training requires '--dataset_mode unaligned' dataset.
By default, it uses a '--netG resnet_9blocks' ResNet generator,
a '--netD basic' discriminator (PatchGAN introduced by pix2pix),
and a least-square GANs objective ('--gan_mode lsgan').
CycleGAN paper: https://arxiv.org/pdf/1703.10593.pdf CycleGAN paper: https://arxiv.org/pdf/1703.10593.pdf
""" """
def __init__(self, cfg): def __init__(self,
generator,
discriminator=None,
cycle_criterion=None,
idt_criterion=None,
gan_criterion=None,
pool_size=50,
direction='a2b',
lambda_a=10.,
lambda_b=10.):
"""Initialize the CycleGAN class. """Initialize the CycleGAN class.
Parameters: Args:
opt (config)-- stores all the experiment flags; needs to be a subclass of Dict generator (dict): config of generator.
discriminator (dict): config of discriminator.
cycle_criterion (dict): config of cycle criterion.
idt_criterion (dict): config of identity criterion.
gan_criterion (dict): config of GAN criterion.
pool_size (int): size of the image buffer that stores previously
generated images.
direction (str): translation direction, 'a2b' or 'b2a'.
lambda_a (float): weight of the forward cycle loss.
lambda_b (float): weight of the backward cycle loss.
""" """
super(CycleGANModel, self).__init__(cfg) super(CycleGANModel, self).__init__()
self.direction = direction
# define networks (both Generators and discriminators) self.lambda_a = lambda_a
self.lambda_b = lambda_b
# define generators
# The naming is different from those used in the paper. # The naming is different from those used in the paper.
# Code (vs. paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X) # Code (vs. paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
self.nets['netG_A'] = build_generator(cfg.model.generator) self.nets['netG_A'] = build_generator(generator)
self.nets['netG_B'] = build_generator(cfg.model.generator) self.nets['netG_B'] = build_generator(generator)
init_weights(self.nets['netG_A']) init_weights(self.nets['netG_A'])
init_weights(self.nets['netG_B']) init_weights(self.nets['netG_B'])
if self.is_train: # define discriminators # define discriminators
self.nets['netD_A'] = build_discriminator(cfg.model.discriminator) if discriminator:
self.nets['netD_B'] = build_discriminator(cfg.model.discriminator) self.nets['netD_A'] = build_discriminator(discriminator)
self.nets['netD_B'] = build_discriminator(discriminator)
init_weights(self.nets['netD_A']) init_weights(self.nets['netD_A'])
init_weights(self.nets['netD_B']) init_weights(self.nets['netD_B'])
if self.is_train:
if cfg.lambda_identity > 0.0: # only works when input and output images have the same number of channels
assert (
cfg.dataset.train.input_nc == cfg.dataset.train.output_nc)
# create image buffer to store previously generated images # create image buffer to store previously generated images
self.fake_A_pool = ImagePool(cfg.dataset.train.pool_size) self.fake_A_pool = ImagePool(pool_size)
# create image buffer to store previously generated images # create image buffer to store previously generated images
self.fake_B_pool = ImagePool(cfg.dataset.train.pool_size) self.fake_B_pool = ImagePool(pool_size)
# define loss functions # define loss functions
self.criterionGAN = GANLoss(cfg.model.gan_mode) if gan_criterion:
self.criterionCycle = paddle.nn.L1Loss() self.gan_criterion = build_criterion(gan_criterion)
self.criterionIdt = paddle.nn.L1Loss()
if cycle_criterion:
self.build_lr_scheduler() self.cycle_criterion = build_criterion(cycle_criterion)
self.optimizers['optimizer_G'] = build_optimizer(
cfg.optimizer, if idt_criterion:
self.lr_scheduler, self.idt_criterion = build_criterion(idt_criterion)
parameter_list=self.nets['netG_A'].parameters() +
self.nets['netG_B'].parameters()) def setup_input(self, input):
self.optimizers['optimizer_D'] = build_optimizer(
cfg.optimizer,
self.lr_scheduler,
parameter_list=self.nets['netD_A'].parameters() +
self.nets['netD_B'].parameters())
def set_input(self, input):
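The per-model optimizer plumbing above is deleted: optimizers now live in the trainer, and losses are built from plain config dicts via `build_criterion`. A minimal sketch of the registry pattern this presumably relies on (names here are illustrative, not the actual ppgan builder):

import copy
import paddle.nn as nn

# Hypothetical registry; the real ppgan.models.criterions builder may differ.
CRITERIONS = {'L1Loss': nn.L1Loss, 'MSELoss': nn.MSELoss}

def build_criterion(cfg):
    """Instantiate a loss from a config dict like {'name': 'L1Loss', ...}."""
    cfg = copy.deepcopy(cfg)           # don't mutate the caller's config
    cls = CRITERIONS[cfg.pop('name')]  # look the class up by its 'name' key
    return cls(**cfg)                  # remaining keys become constructor kwargs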
"""Unpack input data from the dataloader and perform necessary pre-processing steps. """Unpack input data from the dataloader and perform necessary pre-processing steps.
Args: Args:
...@@ -92,8 +92,8 @@ class CycleGANModel(BaseModel): ...@@ -92,8 +92,8 @@ class CycleGANModel(BaseModel):
The option 'direction' can be used to swap domain A and domain B. The option 'direction' can be used to swap domain A and domain B.
""" """
mode = 'train' if self.is_train else 'test'
AtoB = self.cfg.dataset[mode].direction == 'AtoB' AtoB = self.direction == 'a2b'
if AtoB: if AtoB:
if 'A' in input: if 'A' in input:
...@@ -134,20 +134,22 @@ class CycleGANModel(BaseModel): ...@@ -134,20 +134,22 @@ class CycleGANModel(BaseModel):
     def backward_D_basic(self, netD, real, fake):
         """Calculate GAN loss for the discriminator
-        Parameters:
-            netD (network) -- the discriminator D
-            real (tensor array) -- real images
-            fake (tensor array) -- images generated by a generator
+        Args:
+            netD (Layer): the discriminator D
+            real (paddle.Tensor): real images
+            fake (paddle.Tensor): images generated by a generator
-        Return the discriminator loss.
+        Return:
+            the discriminator loss.
         We also call loss_D.backward() to calculate the gradients.
         """
         # Real
         pred_real = netD(real)
-        loss_D_real = self.criterionGAN(pred_real, True)
+        loss_D_real = self.gan_criterion(pred_real, True)
         # Fake
         pred_fake = netD(fake.detach())
-        loss_D_fake = self.criterionGAN(pred_fake, False)
+        loss_D_fake = self.gan_criterion(pred_fake, False)
         # Combined loss and calculate gradients
         loss_D = (loss_D_real + loss_D_fake) * 0.5
@@ -170,16 +172,13 @@ class CycleGANModel(BaseModel):
     def backward_G(self):
         """Calculate the loss for generators G_A and G_B"""
-        lambda_idt = self.cfg.lambda_identity
-        lambda_A = self.cfg.lambda_A
-        lambda_B = self.cfg.lambda_B
         # Identity loss
-        if lambda_idt > 0:
+        if self.idt_criterion:
             # G_A should be identity if real_B is fed: ||G_A(B) - B||
             self.idt_A = self.nets['netG_A'](self.real_B)
-            self.loss_idt_A = self.criterionIdt(
-                self.idt_A, self.real_B) * lambda_B * lambda_idt
+            self.loss_idt_A = self.idt_criterion(self.idt_A,
+                                                 self.real_B) * self.lambda_b
             # G_B should be identity if real_A is fed: ||G_B(A) - A||
             self.idt_B = self.nets['netG_B'](self.real_A)
@@ -187,24 +186,24 @@ class CycleGANModel(BaseModel):
             self.visual_items['idt_A'] = self.idt_A
             self.visual_items['idt_B'] = self.idt_B
-            self.loss_idt_B = self.criterionIdt(
-                self.idt_B, self.real_A) * lambda_A * lambda_idt
+            self.loss_idt_B = self.idt_criterion(self.idt_B,
+                                                 self.real_A) * self.lambda_a
         else:
             self.loss_idt_A = 0
             self.loss_idt_B = 0

         # GAN loss D_A(G_A(A))
-        self.loss_G_A = self.criterionGAN(self.nets['netD_A'](self.fake_B),
-                                          True)
+        self.loss_G_A = self.gan_criterion(self.nets['netD_A'](self.fake_B),
+                                           True)
         # GAN loss D_B(G_B(B))
-        self.loss_G_B = self.criterionGAN(self.nets['netD_B'](self.fake_A),
-                                          True)
+        self.loss_G_B = self.gan_criterion(self.nets['netD_B'](self.fake_A),
+                                           True)
         # Forward cycle loss || G_B(G_A(A)) - A||
-        self.loss_cycle_A = self.criterionCycle(self.rec_A,
-                                                self.real_A) * lambda_A
+        self.loss_cycle_A = self.cycle_criterion(self.rec_A,
+                                                 self.real_A) * self.lambda_a
         # Backward cycle loss || G_A(G_B(B)) - B||
-        self.loss_cycle_B = self.criterionCycle(self.rec_B,
-                                                self.real_B) * lambda_B
+        self.loss_cycle_B = self.cycle_criterion(self.rec_B,
+                                                 self.real_B) * self.lambda_b

         self.losses['G_idt_A_loss'] = self.loss_idt_A
         self.losses['G_idt_B_loss'] = self.loss_idt_B
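Written out, the generator objective assembled above is the usual CycleGAN loss; note that the extra identity weight (formerly `lambda_identity`) is now folded into the `idt_criterion` config rather than being a separate hyper-parameter:

\mathcal{L}_G = \mathcal{L}_{GAN}(G_A, D_A) + \mathcal{L}_{GAN}(G_B, D_B)
              + \lambda_a \lVert G_B(G_A(A)) - A \rVert_1
              + \lambda_b \lVert G_A(G_B(B)) - B \rVert_1
              + \lambda_b \, \mathcal{L}_{idt}(G_A(B), B)
              + \lambda_a \, \mathcal{L}_{idt}(G_B(A), A)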
@@ -217,7 +216,7 @@ class CycleGANModel(BaseModel):
         self.loss_G.backward()

-    def optimize_parameters(self):
+    def train_iter(self, optimizers=None):
         """Calculate losses, gradients, and update network weights; called in every training iteration"""
         # forward
         # compute fake images and reconstruction images.
@@ -227,19 +226,19 @@ class CycleGANModel(BaseModel):
         self.set_requires_grad([self.nets['netD_A'], self.nets['netD_B']],
                                False)
         # set G_A and G_B's gradients to zero
-        self.optimizers['optimizer_G'].clear_grad()
+        optimizers['optimG'].clear_grad()
         # calculate gradients for G_A and G_B
         self.backward_G()
         # update G_A and G_B's weights
-        self.optimizers['optimizer_G'].step()
+        optimizers['optimG'].step()
         # D_A and D_B
         self.set_requires_grad([self.nets['netD_A'], self.nets['netD_B']], True)
         # set D_A and D_B's gradients to zero
-        self.optimizers['optimizer_D'].clear_grad()
+        optimizers['optimD'].clear_grad()
         # calculate gradients for D_A
         self.backward_D_A()
         # calculate gradients for D_B
         self.backward_D_B()
         # update D_A and D_B's weights
-        self.optimizers['optimizer_D'].step()
+        optimizers['optimD'].step()
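With optimizers owned by the trainer rather than the model, one training step reduces to roughly the following (an illustrative sketch, not the actual ppgan trainer; `train_dataloader` and `optimizers` are assumed to be built elsewhere from the config):

# Hypothetical trainer loop for the new model API.
for data in train_dataloader:                # a paddle DataLoader
    model.setup_input(data)                  # unpack A/B images onto the model
    model.train_iter(optimizers=optimizers)  # one G update, then one D update
    for name, value in model.losses.items():
        print(name, float(value))            # scalar logging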
@@ -18,69 +18,53 @@ from .base_model import BaseModel
 from .builder import MODELS
 from .generators.builder import build_generator
 from .discriminators.builder import build_discriminator
-from .losses import GANLoss
-from ..solver import build_optimizer
+from .criterions import build_criterion
 from ..modules.init import init_weights

 @MODELS.register()
 class DCGANModel(BaseModel):
-    """ This class implements the DCGAN model, for learning a distribution from input images.
-    The model training requires dataset.
-    By default, it uses a '--netG DCGenerator' generator,
-    a '--netD DCDiscriminator' discriminator,
-    and a vanilla GAN loss (the cross-entropy objective used in the orignal GAN paper).
+    """
+    This class implements the DCGAN model, for learning a distribution from input images.
     DCGAN paper: https://arxiv.org/pdf/1511.06434
     """
-    def __init__(self, cfg):
+    def __init__(self, generator, discriminator=None, gan_criterion=None):
         """Initialize the DCGAN class.
-        Parameters:
-            opt (config dict)-- stores all the experiment flags; needs to be a subclass of Dict
+        Args:
+            generator (dict): config of generator.
+            discriminator (dict): config of discriminator.
+            gan_criterion (dict): config of gan criterion.
         """
-        super(DCGANModel, self).__init__(cfg)
+        super(DCGANModel, self).__init__()
+        self.gen_cfg = generator

         # define networks (both generator and discriminator)
-        self.nets['netG'] = build_generator(cfg.model.generator)
+        self.nets['netG'] = build_generator(generator)
         init_weights(self.nets['netG'])
-        self.cfg = cfg

         if self.is_train:
-            self.nets['netD'] = build_discriminator(cfg.model.discriminator)
+            self.nets['netD'] = build_discriminator(discriminator)
             init_weights(self.nets['netD'])

-        if self.is_train:
-            self.losses = {}
-            # define loss functions
-            self.criterionGAN = GANLoss(cfg.model.gan_mode)
+        if gan_criterion:
+            self.gan_criterion = build_criterion(gan_criterion)

-            # build optimizers
-            self.build_lr_scheduler()
-            self.optimizers['optimizer_G'] = build_optimizer(
-                cfg.optimizer,
-                self.lr_scheduler,
-                parameter_list=self.nets['netG'].parameters())
-            self.optimizers['optimizer_D'] = build_optimizer(
-                cfg.optimizer,
-                self.lr_scheduler,
-                parameter_list=self.nets['netD'].parameters())

-    def set_input(self, input):
+    def setup_input(self, input):
         """Unpack input data from the dataloader and perform necessary pre-processing steps.
-        Parameters:
+        Args:
             input (dict): include the data itself and its metadata information.
         """
         # get 1-channel gray image, or 3-channel color image
-        self.real = paddle.to_tensor(input['A'][:, 0:self.cfg.model.generator.input_nc, :, :])
-        self.image_paths = input['A_paths']
+        self.real = paddle.to_tensor(input['A'])
+        self.image_paths = input['A_path']

     def forward(self):
-        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
+        """Run forward pass; called by both functions <train_iter> and <test_iter>."""
         # generate random noise and fake image
-        self.z = paddle.rand(shape=(self.real.shape[0], self.cfg.model.generator.input_nz, 1, 1))
+        self.z = paddle.rand(shape=(self.real.shape[0], self.gen_cfg.input_nz,
                                     1, 1))
         self.fake = self.nets['netG'](self.z)

         # put items to visual dict
@@ -91,10 +75,10 @@ class DCGANModel(BaseModel):
         """Calculate GAN loss for the discriminator"""
         # Fake; stop backprop to the generator by detaching fake
         pred_fake = self.nets['netD'](self.fake.detach())
-        self.loss_D_fake = self.criterionGAN(pred_fake, False)
+        self.loss_D_fake = self.gan_criterion(pred_fake, False)
         pred_real = self.nets['netD'](self.real)
-        self.loss_D_real = self.criterionGAN(pred_real, True)
+        self.loss_D_real = self.gan_criterion(pred_real, True)
         # combine loss and calculate gradients
         self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
@@ -108,7 +92,7 @@ class DCGANModel(BaseModel):
         """Calculate GAN loss for the generator"""
         # G(A) should fake the discriminator
         pred_fake = self.nets['netD'](self.fake)
-        self.loss_G_GAN = self.criterionGAN(pred_fake, True)
+        self.loss_G_GAN = self.gan_criterion(pred_fake, True)
         # combine loss and calculate gradients
         self.loss_G = self.loss_G_GAN
@@ -117,7 +101,7 @@ class DCGANModel(BaseModel):
         self.losses['G_adv_loss'] = self.loss_G_GAN

-    def optimize_parameters(self):
+    def train_iter(self, optimizers=None):
         # compute fake images: G(A)
         self.forward()
...
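For DCGAN, the `gan_criterion` config would typically select the vanilla mode, i.e. binary cross-entropy on the discriminator logits. A minimal sketch of that objective (illustrative; not the ppgan GANLoss implementation):

import paddle
import paddle.nn.functional as F

def vanilla_gan_loss(pred, target_is_real):
    """BCE-with-logits against an all-ones or all-zeros label map."""
    target = paddle.ones_like(pred) if target_is_real else paddle.zeros_like(pred)
    return F.binary_cross_entropy_with_logits(pred, target)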
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from .vgg_discriminator import VGGDiscriminator128
 from .nlayers import NLayerDiscriminator, NLayerDiscriminatorWithClassification
 from .discriminator_ugatit import UGATITDiscriminator
 from .dcdiscriminator import DCDiscriminator
...
import paddle.nn as nn
from .builder import DISCRIMINATORS
@DISCRIMINATORS.register()
class VGGDiscriminator128(nn.Layer):
"""VGG style discriminator with input size 128 x 128.
It is used to train SRGAN and ESRGAN.
Args:
        in_channels (int): Channel number of inputs. Default: 3.
        num_feat (int): Channel number of base intermediate features.
            Default: 64.
        norm_layer (str): Type of normalization layer. Default: 'batch'.
"""
def __init__(self, in_channels, num_feat, norm_layer='batch'):
super(VGGDiscriminator128, self).__init__()
self.conv0_0 = nn.Conv2D(in_channels, num_feat, 3, 1, 1, bias_attr=True)
self.conv0_1 = nn.Conv2D(num_feat, num_feat, 4, 2, 1, bias_attr=False)
self.bn0_1 = nn.BatchNorm2D(num_feat)
self.conv1_0 = nn.Conv2D(num_feat,
num_feat * 2,
3,
1,
1,
bias_attr=False)
self.bn1_0 = nn.BatchNorm2D(num_feat * 2)
self.conv1_1 = nn.Conv2D(num_feat * 2,
num_feat * 2,
4,
2,
1,
bias_attr=False)
self.bn1_1 = nn.BatchNorm2D(num_feat * 2)
self.conv2_0 = nn.Conv2D(num_feat * 2,
num_feat * 4,
3,
1,
1,
bias_attr=False)
self.bn2_0 = nn.BatchNorm2D(num_feat * 4)
self.conv2_1 = nn.Conv2D(num_feat * 4,
num_feat * 4,
4,
2,
1,
bias_attr=False)
self.bn2_1 = nn.BatchNorm2D(num_feat * 4)
self.conv3_0 = nn.Conv2D(num_feat * 4,
num_feat * 8,
3,
1,
1,
bias_attr=False)
self.bn3_0 = nn.BatchNorm2D(num_feat * 8)
self.conv3_1 = nn.Conv2D(num_feat * 8,
num_feat * 8,
4,
2,
1,
bias_attr=False)
self.bn3_1 = nn.BatchNorm2D(num_feat * 8)
self.conv4_0 = nn.Conv2D(num_feat * 8,
num_feat * 8,
3,
1,
1,
bias_attr=False)
self.bn4_0 = nn.BatchNorm2D(num_feat * 8)
self.conv4_1 = nn.Conv2D(num_feat * 8,
num_feat * 8,
4,
2,
1,
bias_attr=False)
self.bn4_1 = nn.BatchNorm2D(num_feat * 8)
self.linear1 = nn.Linear(num_feat * 8 * 4 * 4, 100)
self.linear2 = nn.Linear(100, 1)
# activation function
self.lrelu = nn.LeakyReLU(negative_slope=0.2)
def forward(self, x):
assert x.shape[2] == 128 and x.shape[3] == 128, (
f'Input spatial size must be 128x128, '
f'but received {x.shape}.')
feat = self.lrelu(self.conv0_0(x))
feat = self.lrelu(self.bn0_1(
self.conv0_1(feat))) # output spatial size: (64, 64)
feat = self.lrelu(self.bn1_0(self.conv1_0(feat)))
feat = self.lrelu(self.bn1_1(
self.conv1_1(feat))) # output spatial size: (32, 32)
feat = self.lrelu(self.bn2_0(self.conv2_0(feat)))
feat = self.lrelu(self.bn2_1(
self.conv2_1(feat))) # output spatial size: (16, 16)
feat = self.lrelu(self.bn3_0(self.conv3_0(feat)))
feat = self.lrelu(self.bn3_1(
self.conv3_1(feat))) # output spatial size: (8, 8)
feat = self.lrelu(self.bn4_0(self.conv4_0(feat)))
feat = self.lrelu(self.bn4_1(
self.conv4_1(feat))) # output spatial size: (4, 4)
feat = feat.reshape([feat.shape[0], -1])
feat = self.lrelu(self.linear1(feat))
out = self.linear2(feat)
return out
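A quick smoke test of the new discriminator (a hypothetical snippet; the import path assumes the file lands under ppgan/models/discriminators/):

import paddle
from ppgan.models.discriminators.vgg_discriminator import VGGDiscriminator128

disc = VGGDiscriminator128(in_channels=3, num_feat=64)
x = paddle.rand([2, 3, 128, 128])  # anything other than 128x128 trips the assert
out = disc(x)
print(out.shape)                   # [2, 1]: one real/fake logit per image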
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from .generators.builder import build_generator
from .discriminators.builder import build_discriminator
from .sr_model import BaseSRModel
from .builder import MODELS
from .criterions import build_criterion
@MODELS.register()
class ESRGAN(BaseSRModel):
"""
This class implements the ESRGAN model.
ESRGAN paper: https://arxiv.org/pdf/1809.00219.pdf
"""
def __init__(self,
generator,
discriminator=None,
pixel_criterion=None,
perceptual_criterion=None,
gan_criterion=None):
"""Initialize the ESRGAN class.
Args:
generator (dict): config of generator.
discriminator (dict): config of discriminator.
pixel_criterion (dict): config of pixel criterion.
perceptual_criterion (dict): config of perceptual criterion.
gan_criterion (dict): config of gan criterion.
"""
super(ESRGAN, self).__init__(generator)
self.nets['generator'] = build_generator(generator)
if discriminator:
self.nets['discriminator'] = build_discriminator(discriminator)
if pixel_criterion:
self.pixel_criterion = build_criterion(pixel_criterion)
if perceptual_criterion:
self.perceptual_criterion = build_criterion(perceptual_criterion)
if gan_criterion:
self.gan_criterion = build_criterion(gan_criterion)
def train_iter(self, optimizers=None):
self.set_requires_grad(self.nets['discriminator'], False)
optimizers['optimG'].clear_grad()
l_total = 0
self.output = self.nets['generator'](self.lq)
self.visual_items['output'] = self.output
# pixel loss
if self.pixel_criterion:
l_pix = self.pixel_criterion(self.output, self.gt)
l_total += l_pix
self.losses['loss_pix'] = l_pix
        # perceptual loss
        if self.perceptual_criterion:
            l_g_percep, l_g_style = self.perceptual_criterion(
                self.output, self.gt)
if l_g_percep is not None:
l_total += l_g_percep
self.losses['loss_percep'] = l_g_percep
if l_g_style is not None:
l_total += l_g_style
self.losses['loss_style'] = l_g_style
# gan loss (relativistic gan)
real_d_pred = self.nets['discriminator'](self.gt).detach()
fake_g_pred = self.nets['discriminator'](self.output)
l_g_real = self.gan_criterion(real_d_pred - paddle.mean(fake_g_pred),
False,
is_disc=False)
l_g_fake = self.gan_criterion(fake_g_pred - paddle.mean(real_d_pred),
True,
is_disc=False)
l_g_gan = (l_g_real + l_g_fake) / 2
l_total += l_g_gan
self.losses['l_g_gan'] = l_g_gan
l_total.backward()
optimizers['optimG'].step()
self.set_requires_grad(self.nets['discriminator'], True)
optimizers['optimD'].clear_grad()
# real
fake_d_pred = self.nets['discriminator'](self.output).detach()
real_d_pred = self.nets['discriminator'](self.gt)
l_d_real = self.gan_criterion(
real_d_pred - paddle.mean(fake_d_pred), True, is_disc=True) * 0.5
# fake
fake_d_pred = self.nets['discriminator'](self.output.detach())
l_d_fake = self.gan_criterion(
fake_d_pred - paddle.mean(real_d_pred.detach()),
False,
is_disc=True) * 0.5
(l_d_real + l_d_fake).backward()
optimizers['optimD'].step()
self.losses['l_d_real'] = l_d_real
self.losses['l_d_fake'] = l_d_fake
self.losses['out_d_real'] = paddle.mean(real_d_pred.detach())
self.losses['out_d_fake'] = paddle.mean(fake_d_pred.detach())
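The generator and discriminator terms above implement the relativistic average GAN of the ESRGAN paper: each logit is shifted by the mean logit of the opposite class before the usual GAN criterion is applied, and the two generator terms are averaged (`l_g_gan = (l_g_real + l_g_fake) / 2`). In formula form, with C the raw discriminator and sigma the sigmoid:

D_{Ra}(x_r, x_f) = \sigma\big(C(x_r) - \mathbb{E}_{x_f}[C(x_f)]\big)
\mathcal{L}_D = -\mathbb{E}_{x_r}\big[\log D_{Ra}(x_r, x_f)\big] - \mathbb{E}_{x_f}\big[\log(1 - D_{Ra}(x_f, x_r))\big]
\mathcal{L}_G = -\mathbb{E}_{x_r}\big[\log(1 - D_{Ra}(x_r, x_f))\big] - \mathbb{E}_{x_f}\big[\log D_{Ra}(x_f, x_r)\big]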
@@ -19,7 +19,7 @@ from .base_model import BaseModel
 from .builder import MODELS
 from .generators.builder import build_generator
 from .discriminators.builder import build_discriminator
-from .losses import GANLoss
+from .criterions.gan_loss import GANLoss
 from ..solver import build_optimizer
 from ..modules.init import init_weights
@@ -82,7 +82,9 @@ class GANModel(BaseModel):
         self.D_real_inputs = [paddle.to_tensor(input['img'])]
         if 'class_id' in input:  # n class input
             self.n_class = self.nets['netG'].n_class
-            self.D_real_inputs += [paddle.to_tensor(input['class_id'], dtype='int64')]
+            self.D_real_inputs += [
+                paddle.to_tensor(input['class_id'], dtype='int64')
+            ]
         else:
             self.n_class = 0
@@ -97,7 +99,9 @@ class GANModel(BaseModel):
             rows_num = (batch_size - 1) // self.samples_every_row + 1
             class_ids = paddle.randint(0, self.n_class, [rows_num, 1])
             class_ids = class_ids.tile([1, self.samples_every_row])
-            class_ids = class_ids.reshape([-1,])[:batch_size].detach()
+            class_ids = class_ids.reshape([
+                -1,
+            ])[:batch_size].detach()
             self.G_fixed_inputs[1] = class_ids.detach()

     def forward(self):
@@ -105,7 +109,8 @@ class GANModel(BaseModel):
         self.fake_imgs = self.nets['netG'](*self.G_inputs)  # G(img, class_id)

         # put items to visual dict
-        self.visual_items['fake_imgs'] = make_grid(self.fake_imgs, self.samples_every_row).detach()
+        self.visual_items['fake_imgs'] = make_grid(
+            self.fake_imgs, self.samples_every_row).detach()

     def backward_D(self):
         """Calculate GAN loss for the discriminator"""
@@ -118,7 +123,8 @@ class GANModel(BaseModel):
         pred_fake = self.nets['netD'](*self.D_fake_inputs)
         # Real
         real_imgs = self.D_real_inputs[0]
-        self.visual_items['real_imgs'] = make_grid(real_imgs, self.samples_every_row).detach()
+        self.visual_items['real_imgs'] = make_grid(
+            real_imgs, self.samples_every_row).detach()
         pred_real = self.nets['netD'](*self.D_real_inputs)

         self.loss_D_fake = self.criterionGAN(pred_fake, False, True)
@@ -126,7 +132,8 @@ class GANModel(BaseModel):
         # combine loss and calculate gradients
         if self.cfg.model.gan_mode in ['vanilla', 'lsgan']:
-            self.loss_D = self.loss_D + (self.loss_D_fake + self.loss_D_real) * 0.5
+            self.loss_D = self.loss_D + (self.loss_D_fake +
+                                         self.loss_D_real) * 0.5
         else:
             self.loss_D = self.loss_D + self.loss_D_fake + self.loss_D_real
@@ -179,7 +186,7 @@ class GANModel(BaseModel):
         if self.step % self.visual_interval == 0:
             with paddle.no_grad():
                 self.visual_items['fixed_generated_imgs'] = make_grid(
-                    self.nets['netG'](*self.G_fixed_inputs), self.samples_every_row
-                )
+                    self.nets['netG'](*self.G_fixed_inputs),
+                    self.samples_every_row)
         self.step += 1
@@ -15,8 +15,7 @@ import os
 import numpy as np
 import paddle
-import paddle.nn as nn
-import paddle.nn.functional as F
 from paddle.vision.models import vgg16
 from paddle.utils.download import get_path_from_url
 from .base_model import BaseModel
@@ -24,12 +23,10 @@ from .base_model import BaseModel
 from .builder import MODELS
 from .generators.builder import build_generator
 from .discriminators.builder import build_discriminator
-from .losses import GANLoss
+from .criterions import build_criterion
 from ..modules.init import init_weights
-from ..solver import build_optimizer
 from ..utils.image_pool import ImagePool
 from ..utils.preprocess import *
-from ..datasets.makeup_dataset import MakeupDataset

 VGGFACE_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/vggface.pdparams'
@@ -39,18 +36,32 @@ class MakeupModel(BaseModel):
     """
     PSGAN paper: https://arxiv.org/pdf/1909.06956.pdf
     """
-    def __init__(self, cfg):
+    def __init__(self,
+                 generator,
+                 discriminator=None,
+                 cycle_criterion=None,
+                 idt_criterion=None,
+                 gan_criterion=None,
+                 l1_criterion=None,
+                 l2_criterion=None,
+                 pool_size=50,
+                 direction='a2b',
+                 lambda_a=10.,
+                 lambda_b=10.,
+                 is_train=True):
         """Initialize the PSGAN class.
         Parameters:
             cfg (dict)-- config of model.
         """
-        super(MakeupModel, self).__init__(cfg)
+        super(MakeupModel, self).__init__()
+        self.lambda_a = lambda_a
+        self.lambda_b = lambda_b
+        self.is_train = is_train
         # define networks (both Generators and discriminators)
         # The naming is different from those used in the paper.
         # Code (vs. paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
-        self.nets['netG'] = build_generator(cfg.model.generator)
+        self.nets['netG'] = build_generator(generator)
         init_weights(self.nets['netG'], init_type='xavier', init_gain=1.0)

         if self.is_train:  # define discriminators
@@ -61,46 +72,33 @@ class MakeupModel(BaseModel):
             param = paddle.load(vgg_weight_path)
             vgg.load_dict(param)

-            self.nets['netD_A'] = build_discriminator(cfg.model.discriminator)
-            self.nets['netD_B'] = build_discriminator(cfg.model.discriminator)
+            self.nets['netD_A'] = build_discriminator(discriminator)
+            self.nets['netD_B'] = build_discriminator(discriminator)
             init_weights(self.nets['netD_A'], init_type='xavier', init_gain=1.0)
             init_weights(self.nets['netD_B'], init_type='xavier', init_gain=1.0)

-            self.fake_A_pool = ImagePool(
-                cfg.dataset.train.pool_size
-            )  # create image buffer to store previously generated images
-            self.fake_B_pool = ImagePool(
-                cfg.dataset.train.pool_size
-            )  # create image buffer to store previously generated images
+            # create image buffer to store previously generated images
+            self.fake_A_pool = ImagePool(pool_size)
+            self.fake_B_pool = ImagePool(pool_size)
             # define loss functions
-            self.criterionGAN = GANLoss(
-                cfg.model.gan_mode)  #.to(self.device)  # define GAN loss.
-            self.criterionCycle = paddle.nn.L1Loss()
-            self.criterionIdt = paddle.nn.L1Loss()
-            self.criterionL1 = paddle.nn.L1Loss()
-            self.criterionL2 = paddle.nn.MSELoss()
+            if gan_criterion:
+                self.gan_criterion = build_criterion(gan_criterion)
+            if cycle_criterion:
+                self.cycle_criterion = build_criterion(cycle_criterion)
+            if idt_criterion:
+                self.idt_criterion = build_criterion(idt_criterion)
+            if l1_criterion:
+                self.l1_criterion = build_criterion(l1_criterion)
+            if l2_criterion:
+                self.l2_criterion = build_criterion(l2_criterion)

-            self.build_lr_scheduler()
-            self.optimizers['optimizer_G'] = build_optimizer(
-                cfg.optimizer,
-                self.lr_scheduler,
-                parameter_list=self.nets['netG'].parameters())
-            self.optimizers['optimizer_DA'] = build_optimizer(
-                cfg.optimizer,
-                self.lr_scheduler,
-                parameter_list=self.nets['netD_A'].parameters())
-            self.optimizers['optimizer_DB'] = build_optimizer(
-                cfg.optimizer,
-                self.lr_scheduler,
-                parameter_list=self.nets['netD_B'].parameters())

-    def set_input(self, input):
+    def setup_input(self, input):
         """Unpack input data from the dataloader and perform necessary pre-processing steps.
-        Parameters:
+        Args:
             input (dict): include the data itself and its metadata information.
-            The option 'direction' can be used to swap domain A and domain B.
         """
         self.real_A = paddle.to_tensor(input['image_A'])
         self.real_B = paddle.to_tensor(input['image_B'])
@@ -143,24 +141,6 @@ class MakeupModel(BaseModel):
         self.visual_items['fake_A'] = self.fake_A
         self.visual_items['rec_B'] = self.rec_B

-    def forward_test(self, input):
-        '''
-        not implement now
-        '''
-        return self.nets['netG'](input['image_A'], input['image_B'],
-                                 input['P_A'], input['P_B'],
-                                 input['consis_mask'], input['mask_A_aug'],
-                                 input['mask_B_aug'])
-
-    def test(self, input):
-        """Forward function used in test time.
-        This function wraps <forward> function in no_grad() so we don't save intermediate steps for backprop
-        It also calls <compute_visuals> to produce additional visualization results
-        """
-        with paddle.no_grad():
-            return self.forward_test(input)

     def backward_D_basic(self, netD, real, fake):
         """Calculate GAN loss for the discriminator
@@ -174,10 +154,10 @@ class MakeupModel(BaseModel):
         """
         # Real
         pred_real = netD(real)
-        loss_D_real = self.criterionGAN(pred_real, True)
+        loss_D_real = self.gan_criterion(pred_real, True)
         # Fake
         pred_fake = netD(fake.detach())
-        loss_D_fake = self.criterionGAN(pred_fake, False)
+        loss_D_fake = self.gan_criterion(pred_fake, False)
         # Combined loss and calculate gradients
         loss_D = (loss_D_real + loss_D_fake) * 0.5
         loss_D.backward()
@@ -200,24 +180,24 @@ class MakeupModel(BaseModel):
     def backward_G(self):
         """Calculate the loss for generators G_A and G_B"""
-        lambda_idt = self.cfg.lambda_identity
-        lambda_A = self.cfg.lambda_A
-        lambda_B = self.cfg.lambda_B
+        lambda_A = self.lambda_a
+        lambda_B = self.lambda_b
         lambda_vgg = 5e-3

         # Identity loss
-        if lambda_idt > 0:
+        if self.idt_criterion:
             self.idt_A, _ = self.nets['netG'](self.real_A, self.real_A,
                                               self.P_A, self.P_A,
                                               self.c_m_idt_a, self.mask_A_aug,
                                               self.mask_B_aug)  # G_A(A)
-            self.loss_idt_A = self.criterionIdt(
-                self.idt_A, self.real_A) * lambda_A * lambda_idt
+            self.loss_idt_A = self.idt_criterion(self.idt_A,
+                                                 self.real_A) * lambda_A
             self.idt_B, _ = self.nets['netG'](self.real_B, self.real_B,
                                               self.P_B, self.P_B,
                                               self.c_m_idt_b, self.mask_A_aug,
                                               self.mask_B_aug)  # G_B(B)
-            self.loss_idt_B = self.criterionIdt(
-                self.idt_B, self.real_B) * lambda_B * lambda_idt
+            self.loss_idt_B = self.idt_criterion(self.idt_B,
+                                                 self.real_B) * lambda_B

             # visual
             self.visual_items['idt_A'] = self.idt_A
@@ -227,16 +207,16 @@ class MakeupModel(BaseModel):
             self.loss_idt_B = 0

         # GAN loss D_A(G_A(A))
-        self.loss_G_A = self.criterionGAN(self.nets['netD_A'](self.fake_A),
-                                          True)
+        self.loss_G_A = self.gan_criterion(self.nets['netD_A'](self.fake_A),
+                                           True)
         # GAN loss D_B(G_B(B))
-        self.loss_G_B = self.criterionGAN(self.nets['netD_B'](self.fake_B),
-                                          True)
+        self.loss_G_B = self.gan_criterion(self.nets['netD_B'](self.fake_B),
+                                           True)
         # Forward cycle loss || G_B(G_A(A)) - A||
-        self.loss_cycle_A = self.criterionCycle(self.rec_A,
-                                                self.real_A) * lambda_A
+        self.loss_cycle_A = self.cycle_criterion(self.rec_A,
+                                                 self.real_A) * lambda_A
         # Backward cycle loss || G_A(G_B(B)) - B||
-        self.loss_cycle_B = self.criterionCycle(self.rec_B,
-                                                self.real_B) * lambda_B
+        self.loss_cycle_B = self.cycle_criterion(self.rec_B,
                                                  self.real_B) * lambda_B

         self.losses['G_A_adv_loss'] = self.loss_G_A
@@ -270,8 +250,10 @@ class MakeupModel(BaseModel):
         fake_match_lip_B = fake_match_lip_B.unsqueeze(0)
         fake_A_lip_masked = fake_A * mask_A_lip
         fake_B_lip_masked = fake_B * mask_B_lip
-        g_A_lip_loss_his = self.criterionL1(fake_A_lip_masked, fake_match_lip_A)
-        g_B_lip_loss_his = self.criterionL1(fake_B_lip_masked, fake_match_lip_B)
+        g_A_lip_loss_his = self.l1_criterion(fake_A_lip_masked,
+                                             fake_match_lip_A)
+        g_B_lip_loss_his = self.l1_criterion(fake_B_lip_masked,
+                                             fake_match_lip_B)

         # skin
         mask_A_skin = self.mask_A_aug[:, 1].unsqueeze(1)
@@ -294,9 +276,9 @@ class MakeupModel(BaseModel):
         fake_match_skin_B = fake_match_skin_B.unsqueeze(0)
         fake_A_skin_masked = fake_A * mask_A_skin
         fake_B_skin_masked = fake_B * mask_B_skin
-        g_A_skin_loss_his = self.criterionL1(fake_A_skin_masked,
-                                             fake_match_skin_A)
-        g_B_skin_loss_his = self.criterionL1(fake_B_skin_masked,
-                                             fake_match_skin_B)
+        g_A_skin_loss_his = self.l1_criterion(fake_A_skin_masked,
+                                              fake_match_skin_A)
+        g_B_skin_loss_his = self.l1_criterion(fake_B_skin_masked,
+                                              fake_match_skin_B)

         # eye
@@ -320,8 +302,10 @@ class MakeupModel(BaseModel):
         fake_match_eye_B = fake_match_eye_B.unsqueeze(0)
         fake_A_eye_masked = fake_A * mask_A_eye
         fake_B_eye_masked = fake_B * mask_B_eye
-        g_A_eye_loss_his = self.criterionL1(fake_A_eye_masked, fake_match_eye_A)
-        g_B_eye_loss_his = self.criterionL1(fake_B_eye_masked, fake_match_eye_B)
+        g_A_eye_loss_his = self.l1_criterion(fake_A_eye_masked,
+                                             fake_match_eye_A)
+        g_B_eye_loss_his = self.l1_criterion(fake_B_eye_masked,
+                                             fake_match_eye_B)

         self.loss_G_A_his = (g_A_eye_loss_his + g_A_lip_loss_his +
                              g_A_skin_loss_his * 0.1) * 0.1
@@ -335,13 +319,13 @@ class MakeupModel(BaseModel):
         vgg_s = self.vgg(self.real_A)
         vgg_s.stop_gradient = True
         vgg_fake_A = self.vgg(self.fake_A)
-        self.loss_A_vgg = self.criterionL2(vgg_fake_A,
-                                           vgg_s) * lambda_A * lambda_vgg
+        self.loss_A_vgg = self.l2_criterion(vgg_fake_A,
+                                            vgg_s) * lambda_A * lambda_vgg

         vgg_r = self.vgg(self.real_B)
         vgg_r.stop_gradient = True
         vgg_fake_B = self.vgg(self.fake_B)
-        self.loss_B_vgg = self.criterionL2(vgg_fake_B,
-                                           vgg_r) * lambda_B * lambda_vgg
+        self.loss_B_vgg = self.l2_criterion(vgg_fake_B,
+                                            vgg_r) * lambda_B * lambda_vgg

         self.loss_rec = (self.loss_cycle_A * 0.2 + self.loss_cycle_B * 0.2 +
@@ -359,15 +343,14 @@ class MakeupModel(BaseModel):
             (self.mask_A == 10), dtype='float32') + paddle.cast(
                 (self.mask_A == 8), dtype='float32')
         mask_A_consis = paddle.unsqueeze(paddle.clip(mask_A_consis, 0, 1), 1)
-        self.loss_G_bg_consis = self.criterionL1(
-            self.real_A * mask_A_consis, self.fake_A * mask_A_consis) * 0.1
+        self.loss_G_bg_consis = self.l1_criterion(
+            self.real_A * mask_A_consis, self.fake_A * mask_A_consis) * 0.1

         # combined loss and calculate gradients
         self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_rec + self.loss_idt + self.loss_G_A_his + self.loss_G_B_his + self.loss_G_bg_consis
         self.loss_G.backward()
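As assembled above, the full PSGAN generator objective is the sum of the adversarial, reconstruction, identity, histogram-matching, and background-consistency terms:

\mathcal{L}_G = \mathcal{L}_{G_A}^{adv} + \mathcal{L}_{G_B}^{adv} + \mathcal{L}_{rec} + \mathcal{L}_{idt} + \mathcal{L}_{his}^{A} + \mathcal{L}_{his}^{B} + \mathcal{L}_{bg}

where each histogram term is itself a weighted sum of the masked L1 region losses, 0.1 * (eye + lip + 0.1 * skin), as computed in the region blocks above.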
-    def optimize_parameters(self):
+    def train_iter(self, optimizers=None):
         """Calculate losses, gradients, and update network weights; called in every training iteration"""
         # forward
         self.forward()  # compute fake images and reconstruction images.
@@ -392,5 +375,4 @@ class MakeupModel(BaseModel):
         self.backward_D_B()  # calculate gradients for D_B
         self.optimizers['optimizer_DB'].minimize(
             self.loss_D_B)  #step() # update D_A and D_B's weights
-        self.optimizers['optimizer_DB'].clear_gradients(
-        )  #zero_grad() # set D_A and D_B's gradients to zero
+        self.optimizers['optimizer_DB'].clear_gradients()
@@ -18,7 +18,7 @@ from .base_model import BaseModel
 from .builder import MODELS
 from .generators.builder import build_generator
 from .discriminators.builder import build_discriminator
-from .losses import GANLoss
+from .criterions import build_criterion
 from ..solver import build_optimizer
 from ..modules.init import init_weights
@@ -29,63 +29,57 @@ from ..utils.image_pool import ImagePool
 class Pix2PixModel(BaseModel):
     """ This class implements the pix2pix model, for learning a mapping from input images to output images given paired data.
-    The model training requires 'paired' dataset.
-    By default, it uses a '--netG unet256' U-Net generator,
-    a '--netD basic' discriminator (from PatchGAN),
-    and a vanilla GAN loss (the cross-entropy objective used in the orignal GAN paper).
     pix2pix paper: https://arxiv.org/pdf/1611.07004.pdf
     """
-    def __init__(self, cfg):
+    def __init__(self,
+                 generator,
+                 discriminator=None,
+                 pixel_criterion=None,
+                 gan_criterion=None,
+                 direction='a2b'):
         """Initialize the pix2pix class.
-        Parameters:
-            opt (config dict)-- stores all the experiment flags; needs to be a subclass of Dict
+        Args:
+            generator (dict): config of generator.
+            discriminator (dict): config of discriminator.
+            pixel_criterion (dict): config of pixel criterion.
+            gan_criterion (dict): config of gan criterion.
         """
-        super(Pix2PixModel, self).__init__(cfg)
+        super(Pix2PixModel, self).__init__()
+        self.direction = direction
         # define networks (both generator and discriminator)
-        self.nets['netG'] = build_generator(cfg.model.generator)
+        self.nets['netG'] = build_generator(generator)
         init_weights(self.nets['netG'])

         # define a discriminator; conditional GANs need to take both input and output images; Therefore, #channels for D is input_nc + output_nc
-        if self.is_train:
-            self.nets['netD'] = build_discriminator(cfg.model.discriminator)
+        if discriminator:
+            self.nets['netD'] = build_discriminator(discriminator)
             init_weights(self.nets['netD'])

-        if self.is_train:
-            self.losses = {}
-            # define loss functions
-            self.criterionGAN = GANLoss(cfg.model.gan_mode)
-            self.criterionL1 = paddle.nn.L1Loss()
+        if pixel_criterion:
+            self.pixel_criterion = build_criterion(pixel_criterion)
+        if gan_criterion:
+            self.gan_criterion = build_criterion(gan_criterion)

-            # build optimizers
-            self.build_lr_scheduler()
-            self.optimizers['optimizer_G'] = build_optimizer(
-                cfg.optimizer,
-                self.lr_scheduler,
-                parameter_list=self.nets['netG'].parameters())
-            self.optimizers['optimizer_D'] = build_optimizer(
-                cfg.optimizer,
-                self.lr_scheduler,
-                parameter_list=self.nets['netD'].parameters())

-    def set_input(self, input):
+    def setup_input(self, input):
         """Unpack input data from the dataloader and perform necessary pre-processing steps.
-        Parameters:
+        Args:
             input (dict): include the data itself and its metadata information.
             The option 'direction' can be used to swap images in domain A and domain B.
         """
-        AtoB = self.cfg.dataset.train.direction == 'AtoB'
+        AtoB = self.direction == 'a2b'
         self.real_A = paddle.fluid.dygraph.to_variable(
             input['A' if AtoB else 'B'])
         self.real_B = paddle.fluid.dygraph.to_variable(
             input['B' if AtoB else 'A'])
-        self.image_paths = input['A_paths' if AtoB else 'B_paths']
+        self.image_paths = input['A_path' if AtoB else 'B_path']

     def forward(self):
-        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
+        """Run forward pass; called by both functions <train_iter> and <test_iter>."""
@@ -102,11 +96,11 @@ class Pix2PixModel(BaseModel):
         # use conditional GANs; we need to feed both input and output to the discriminator
         fake_AB = paddle.concat((self.real_A, self.fake_B), 1)
         pred_fake = self.nets['netD'](fake_AB.detach())
-        self.loss_D_fake = self.criterionGAN(pred_fake, False)
+        self.loss_D_fake = self.gan_criterion(pred_fake, False)
         # Real
         real_AB = paddle.concat((self.real_A, self.real_B), 1)
         pred_real = self.nets['netD'](real_AB)
-        self.loss_D_real = self.criterionGAN(pred_real, True)
+        self.loss_D_real = self.gan_criterion(pred_real, True)
         # combine loss and calculate gradients
         self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
@@ -120,10 +114,9 @@ class Pix2PixModel(BaseModel):
         # First, G(A) should fake the discriminator
         fake_AB = paddle.concat((self.real_A, self.fake_B), 1)
         pred_fake = self.nets['netD'](fake_AB)
-        self.loss_G_GAN = self.criterionGAN(pred_fake, True)
+        self.loss_G_GAN = self.gan_criterion(pred_fake, True)
         # Second, G(A) = B
-        self.loss_G_L1 = self.criterionL1(self.fake_B,
-                                          self.real_B) * self.cfg.lambda_L1
+        self.loss_G_L1 = self.pixel_criterion(self.fake_B, self.real_B)
         # combine loss and calculate gradients
         self.loss_G = self.loss_G_GAN + self.loss_G_L1
@@ -133,18 +126,18 @@ class Pix2PixModel(BaseModel):
         self.losses['G_adv_loss'] = self.loss_G_GAN
         self.losses['G_L1_loss'] = self.loss_G_L1

-    def optimize_parameters(self):
+    def train_iter(self, optimizers=None):
         # compute fake images: G(A)
         self.forward()

         # update D
         self.set_requires_grad(self.nets['netD'], True)
-        self.optimizers['optimizer_D'].clear_grad()
+        optimizers['optimD'].clear_grad()
         self.backward_D()
-        self.optimizers['optimizer_D'].step()
+        optimizers['optimD'].step()

         # update G
         self.set_requires_grad(self.nets['netD'], False)
-        self.optimizers['optimizer_G'].clear_grad()
+        optimizers['optimG'].clear_grad()
         self.backward_G()
-        self.optimizers['optimizer_G'].step()
+        optimizers['optimG'].step()
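The fixed `lambda_L1` hyper-parameter is gone: the L1 weight now travels with the criterion config. The equivalent of the old `lambda_L1 = 100` setup would presumably be expressed along these lines (illustrative dicts; the generator and discriminator names are assumptions, and `loss_weight` follows the key used by the new criterion configs):

# Hypothetical criterion configs for Pix2PixModel under the new API.
pixel_criterion = {'name': 'L1Loss', 'loss_weight': 100.}  # replaces cfg.lambda_L1
gan_criterion = {'name': 'GANLoss', 'gan_mode': 'vanilla'}

model = Pix2PixModel(generator={'name': 'UnetGenerator'},          # assumed name
                     discriminator={'name': 'NLayerDiscriminator'},  # assumed name
                     pixel_criterion=pixel_criterion,
                     gan_criterion=gan_criterion)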
@@ -12,75 +12,69 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from collections import OrderedDict
 import paddle
 import paddle.nn as nn

 from .generators.builder import build_generator
-from .discriminators.builder import build_discriminator
+from .criterions.builder import build_criterion
-from ..solver import build_optimizer
 from .base_model import BaseModel
-from .losses import GANLoss
-from .builder import MODELS
-import importlib
-from collections import OrderedDict
-from copy import deepcopy
-from os import path as osp
 from .builder import MODELS
+from ..utils.visual import tensor2img

 @MODELS.register()
-class SRModel(BaseModel):
-    """Base SR model for single image super-resolution."""
+class BaseSRModel(BaseModel):
+    """Base SR model for single image super-resolution.
+    """
-    def __init__(self, cfg):
-        super(SRModel, self).__init__(cfg)
-
-        self.model_names = ['G']
-
-        self.netG = build_generator(cfg.model.generator)
-        self.visual_names = ['lq', 'output', 'gt']
-
-        self.loss_names = ['l_total']
-        self.optimizers = []
-        if self.is_train:
-            self.criterionL1 = paddle.nn.L1Loss()
-            self.build_lr_scheduler()
-            self.optimizer_G = build_optimizer(
-                cfg.optimizer,
-                self.lr_scheduler,
-                parameter_list=self.netG.parameters())
-            self.optimizers.append(self.optimizer_G)
+    def __init__(self, generator, pixel_criterion=None):
+        """
+        Args:
+            generator (dict): config of generator.
+            pixel_criterion (dict): config of pixel criterion.
+        """
+        super(BaseSRModel, self).__init__()
+
+        self.nets['generator'] = build_generator(generator)
+
+        if pixel_criterion:
+            self.pixel_criterion = build_criterion(pixel_criterion)

-    def set_input(self, input):
-        self.lq = paddle.to_tensor(input['lq'])
+    def setup_input(self, input):
+        self.lq = paddle.fluid.dygraph.to_variable(input['lq'])
+        self.visual_items['lq'] = self.lq
         if 'gt' in input:
-            self.gt = paddle.to_tensor(input['gt'])
+            self.gt = paddle.fluid.dygraph.to_variable(input['gt'])
+            self.visual_items['gt'] = self.gt
         self.image_paths = input['lq_path']

     def forward(self):
         pass

-    def test(self):
-        """Forward function used in test time.
-        """
-        with paddle.no_grad():
-            self.output = self.netG(self.lq)
-
-    def optimize_parameters(self):
-        self.optimizer_G.clear_grad()
-        self.output = self.netG(self.lq)
-
-        l_total = 0
-        loss_dict = OrderedDict()
-        # pixel loss
-        if self.criterionL1:
-            l_pix = self.criterionL1(self.output, self.gt)
-            l_total += l_pix
-            loss_dict['l_pix'] = l_pix
-
-        l_total.backward()
-        self.loss_l_total = l_total
-        self.optimizer_G.step()
+    def train_iter(self, optims=None):
+        optims['optim'].clear_grad()
+
+        self.output = self.nets['generator'](self.lq)
+        self.visual_items['output'] = self.output
+        # pixel loss
+        loss_pixel = self.pixel_criterion(self.output, self.gt)
+        self.losses['loss_pixel'] = loss_pixel
+
+        loss_pixel.backward()
+        optims['optim'].step()
+
+    def test_iter(self, metrics=None):
+        self.nets['generator'].eval()
+        with paddle.no_grad():
+            self.output = self.nets['generator'](self.lq)
+            self.visual_items['output'] = self.output
+        self.nets['generator'].train()
+
+        out_img = []
+        gt_img = []
+        for out_tensor, gt_tensor in zip(self.output, self.gt):
+            out_img.append(tensor2img(out_tensor, (0., 1.)))
+            gt_img.append(tensor2img(gt_tensor, (0., 1.)))
+
+        if metrics is not None:
+            for metric in metrics.values():
+                metric.update(out_img, gt_img)
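`test_iter` hands per-image numpy arrays to metric objects that accumulate state across the validation set. A minimal PSNR-style metric compatible with this `update(out, gt)` protocol might look like the following (an illustrative sketch, not ppgan's metric class):

import numpy as np

class PSNRMetric:
    """Accumulates per-image PSNR; accumulate() returns the running mean."""
    def __init__(self, max_value=255.):
        self.max_value = max_value
        self.scores = []

    def update(self, outputs, gts):
        for out, gt in zip(outputs, gts):
            mse = np.mean((out.astype('float64') - gt.astype('float64'))**2)
            self.scores.append(10 * np.log10(self.max_value**2 / mse))

    def accumulate(self):
        return float(np.mean(self.scores))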
This diff is collapsed.
@@ -19,8 +19,6 @@ import paddle
 from ..utils.logger import get_logger

-logger = get_logger('init')
-
 def _calculate_fan_in_and_fan_out(tensor):
     dimensions = len(tensor.shape)
@@ -313,5 +311,6 @@ def init_weights(net, init_type='normal', init_gain=0.02):
             normal_(m.weight, 1.0, init_gain)
             constant_(m.bias, 0.0)

+    logger = get_logger()
     logger.debug('initialize network with %s' % init_type)
     net.apply(init_func)  # apply the initialization function <init_func>
@@ -12,4 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from .optimizer import build_optimizer
+from .lr_scheduler import CosineAnnealingRestartLR, LinearDecay
+from .optimizer import *
+from .builder import build_lr_scheduler
+from .builder import build_optimizer

This diff is collapsed.
@@ -15,14 +15,9 @@
 import copy
 import paddle

-from .lr_scheduler import build_lr_scheduler
+from .builder import OPTIMIZERS

-def build_optimizer(cfg, lr_scheduler, parameter_list=None):
-    cfg_copy = copy.deepcopy(cfg)
-
-    opt_name = cfg_copy.pop('name')
-
-    return getattr(paddle.optimizer, opt_name)(lr_scheduler,
-                                               parameters=parameter_list,
-                                               **cfg_copy)
+OPTIMIZERS.register(paddle.optimizer.Adam)
+OPTIMIZERS.register(paddle.optimizer.SGD)
+OPTIMIZERS.register(paddle.optimizer.Momentum)
+OPTIMIZERS.register(paddle.optimizer.RMSProp)
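`build_optimizer` itself now lives in `solver/builder.py` (collapsed in this view); under the registry pattern above it would presumably mirror the old getattr-based builder, roughly like this sketch (the `OPTIMIZERS.get` lookup is an assumption about the registry API):

# Hypothetical solver/builder.py logic; the collapsed file may differ.
import copy

def build_optimizer(cfg, lr_scheduler, parameter_list=None):
    cfg_copy = copy.deepcopy(cfg)
    name = cfg_copy.pop('name')
    # fetch the registered optimizer class and bind it to the shared scheduler
    return OPTIMIZERS.get(name)(lr_scheduler,
                                parameters=parameter_list,
                                **cfg_copy)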
This diff is collapsed.

This diff is collapsed.