Unverified  Commit d76c74e4 authored by LielinJiang, committed by GitHub

Add esrgan model and refine codes (#104)

* add esrgan model, refine codes
Parent 05483bee
epochs: 200
output_dir: output_dir
lambda_A: 10.0
lambda_B: 10.0
lambda_identity: 0.5
model:
name: CycleGANModel
@@ -20,60 +17,96 @@ model:
n_layers: 3
norm_type: instance
input_nc: 3
gan_mode: lsgan
cycle_criterion:
name: L1Loss
idt_criterion:
name: L1Loss
loss_weight: 0.5
gan_criterion:
name: GANLoss
gan_mode: lsgan
dataset:
train:
name: UnpairedDataset
dataroot: data/cityscapes
dataroot_a: data/cityscapes/trainA
dataroot_b: data/cityscapes/trainB
num_workers: 0
batch_size: 1
phase: train
max_dataset_size: inf
direction: AtoB
input_nc: 3
output_nc: 3
serial_batches: False
pool_size: 50
transforms:
- name: Resize
size: [286, 286]
interpolation: 'bicubic' #cv2.INTER_CUBIC
- name: RandomCrop
size: [256, 256]
- name: RandomHorizontalFlip
prob: 0.5
- name: Transpose
- name: Normalize
mean: [127.5, 127.5, 127.5]
std: [127.5, 127.5, 127.5]
is_train: True
max_size: inf
preprocess:
- name: LoadImageFromFile
key: A
- name: LoadImageFromFile
key: B
- name: Transforms
input_keys: [A, B]
pipeline:
- name: Resize
size: [286, 286]
interpolation: 'bicubic' #cv2.INTER_CUBIC
keys: ['image', 'image']
- name: RandomCrop
size: [256, 256]
keys: ['image', 'image']
- name: RandomHorizontalFlip
prob: 0.5
keys: ['image', 'image']
- name: Transpose
keys: ['image', 'image']
- name: Normalize
mean: [127.5, 127.5, 127.5]
std: [127.5, 127.5, 127.5]
keys: ['image', 'image']
test:
name: SingleDataset
dataroot: data/cityscapes/testB
max_dataset_size: inf
direction: BtoA
input_nc: 3
output_nc: 3
serial_batches: False
pool_size: 50
transforms:
- name: Resize
size: [256, 256]
interpolation: 'bicubic' #cv2.INTER_CUBIC
- name: Transpose
- name: Normalize
mean: [127.5, 127.5, 127.5]
std: [127.5, 127.5, 127.5]
optimizer:
name: Adam
beta1: 0.5
name: UnpairedDataset
dataroot_a: data/cityscapes/testA
dataroot_b: data/cityscapes/testB
num_workers: 0
batch_size: 1
max_size: inf
is_train: False
load_pipeline:
- name: LoadImageFromFile
key: A
- name: LoadImageFromFile
key: B
- name: Transforms
input_keys: [A, B]
pipeline:
- name: Resize
size: [256, 256]
interpolation: 'bicubic' #cv2.INTER_CUBIC
keys: ['image', 'image']
- name: Transpose
keys: ['image', 'image']
- name: Normalize
mean: [127.5, 127.5, 127.5]
std: [127.5, 127.5, 127.5]
keys: ['image', 'image']
lr_scheduler:
name: linear
name: LinearDecay
learning_rate: 0.0002
start_epoch: 100
decay_epochs: 100
# will get from real dataset
iters_per_epoch: 1
optimizer:
optimG:
name: Adam
net_names:
- netG_A
- netG_B
beta1: 0.5
optimD:
name: Adam
net_names:
- netD_A
- netD_B
beta1: 0.5
log_config:
interval: 100
......
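The refactored config above replaces the single shared optimizer with named optimizers (optimG, optimD), each bound to specific networks through net_names. Below is a minimal sketch, not PaddleGAN's actual builder, of how a trainer could consume this layout, assuming a nets dict keyed by network name and an already-built learning-rate scheduler:

import paddle

def build_optimizers(optim_cfg, nets, lr_scheduler):
    # optim_cfg is the `optimizer` dict above; nets maps name -> paddle.nn.Layer
    optimizers = {}
    for opt_name, opt in optim_cfg.items():
        params = []
        for net_name in opt['net_names']:
            params.extend(nets[net_name].parameters())  # e.g. netG_A + netG_B
        optimizers[opt_name] = paddle.optimizer.Adam(
            learning_rate=lr_scheduler,
            beta1=opt.get('beta1', 0.9),
            parameters=params)
    return optimizers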
epochs: 200
output_dir: output_dir
lambda_A: 10.0
lambda_B: 10.0
lambda_identity: 0.5
model:
name: CycleGANModel
@@ -20,60 +17,96 @@ model:
n_layers: 3
norm_type: instance
input_nc: 3
gan_mode: lsgan
cycle_criterion:
name: L1Loss
idt_criterion:
name: L1Loss
loss_weight: 0.5
gan_criterion:
name: GANLoss
gan_mode: lsgan
dataset:
train:
name: UnpairedDataset
dataroot: data/horse2zebra
dataroot_a: data/horse2zebra/trainA
dataroot_b: data/horse2zebra/trainB
num_workers: 0
batch_size: 1
phase: train
max_dataset_size: inf
direction: AtoB
input_nc: 3
output_nc: 3
serial_batches: False
pool_size: 50
transforms:
- name: Resize
size: [286, 286]
interpolation: 'bicubic' #cv2.INTER_CUBIC
- name: RandomCrop
size: [256, 256]
- name: RandomHorizontalFlip
prob: 0.5
- name: Transpose
- name: Normalize
mean: [127.5, 127.5, 127.5]
std: [127.5, 127.5, 127.5]
is_train: True
max_size: inf
load_pipeline:
- name: LoadImageFromFile
key: A
- name: LoadImageFromFile
key: B
- name: Transforms
input_keys: [A, B]
pipeline:
- name: Resize
size: [286, 286]
interpolation: 'bicubic' #cv2.INTER_CUBIC
keys: ['image', 'image']
- name: RandomCrop
size: [256, 256]
keys: ['image', 'image']
- name: RandomHorizontalFlip
prob: 0.5
keys: ['image', 'image']
- name: Transpose
keys: ['image', 'image']
- name: Normalize
mean: [127.5, 127.5, 127.5]
std: [127.5, 127.5, 127.5]
keys: ['image', 'image']
test:
name: SingleDataset
dataroot: data/horse2zebra/testA
max_dataset_size: inf
direction: AtoB
input_nc: 3
output_nc: 3
serial_batches: False
pool_size: 50
transforms:
- name: Resize
size: [256, 256]
interpolation: 'bicubic' #cv2.INTER_CUBIC
- name: Transpose
- name: Normalize
mean: [127.5, 127.5, 127.5]
std: [127.5, 127.5, 127.5]
optimizer:
name: Adam
beta1: 0.5
name: UnpairedDataset
dataroot_a: data/horse2zebra/testA
dataroot_b: data/horse2zebra/testB
num_workers: 0
batch_size: 1
max_size: inf
is_train: False
load_pipeline:
- name: LoadImageFromFile
key: A
- name: LoadImageFromFile
key: B
- name: Transforms
input_keys: [A, B]
pipeline:
- name: Resize
size: [256, 256]
interpolation: 'bicubic' #cv2.INTER_CUBIC
keys: ['image', 'image']
- name: Transpose
keys: ['image', 'image']
- name: Normalize
mean: [127.5, 127.5, 127.5]
std: [127.5, 127.5, 127.5]
keys: ['image', 'image']
lr_scheduler:
name: linear
name: LinearDecay
learning_rate: 0.0002
start_epoch: 100
decay_epochs: 100
# will get from real dataset
iters_per_epoch: 1
optimizer:
optimG:
name: Adam
net_names:
- netG_A
- netG_B
beta1: 0.5
optimD:
name: Adam
net_names:
- netD_A
- netD_B
beta1: 0.5
log_config:
interval: 100
......
@@ -15,52 +15,70 @@ model:
norm_type: batch
ndf: 64
input_nc: 1
gan_mode: vanilla #wgangp
gan_criterion:
name: GANLoss
gan_mode: vanilla
dataset:
train:
name: SingleDataset
dataroot: data/mnist/train
phase: train
max_dataset_size: inf
direction: AtoB
input_nc: 1
output_nc: 1
batch_size: 128
serial_batches: False
transforms:
- name: Resize
size: [64, 64]
interpolation: 'bicubic' #cv2.INTER_CUBIC
- name: Transpose
- name: Normalize
mean: [127.5, 127.5, 127.5]
std: [127.5, 127.5, 127.5]
preprocess:
- name: LoadImageFromFile
key: A
- name: Transforms
input_keys: [A]
pipeline:
- name: Resize
size: [64, 64]
interpolation: 'bicubic' #cv2.INTER_CUBIC
keys: [image, image]
- name: Transpose
keys: [image, image]
- name: Normalize
mean: [127.5, 127.5, 127.5]
std: [127.5, 127.5, 127.5]
keys: [image, image]
test:
name: SingleDataset
dataroot: data/mnist/test
max_dataset_size: inf
input_nc: 1
output_nc: 1
serial_batches: False
transforms:
- name: Resize
size: [64, 64]
interpolation: 'bicubic' #cv2.INTER_CUBIC
- name: Transpose
- name: Normalize
mean: [127.5, 127.5, 127.5]
std: [127.5, 127.5, 127.5]
optimizer:
name: Adam
beta1: 0.5
preprocess:
- name: LoadImageFromFile
key: A
- name: Transforms
input_keys: [A]
pipeline:
- name: Resize
size: [64, 64]
interpolation: 'bicubic' #cv2.INTER_CUBIC
keys: [image, image]
- name: Transpose
keys: [image, image]
- name: Normalize
mean: [127.5, 127.5, 127.5]
std: [127.5, 127.5, 127.5]
keys: [image, image]
lr_scheduler:
name: linear
learning_rate: 0.00002
name: LinearDecay
learning_rate: 0.0002
start_epoch: 100
decay_epochs: 100
# will get from real dataset
iters_per_epoch: 1
optimizer:
optimizer_G:
name: Adam
net_names:
- netG
beta1: 0.5
optimizer_D:
name: Adam
net_names:
- netD
beta1: 0.5
log_config:
interval: 100
......
total_iters: 1000000
output_dir: output_dir
# tensor range for function tensor2img
min_max:
(0., 1.)
model:
name: BaseSRModel
generator:
name: RRDBNet
in_nc: 3
out_nc: 3
nf: 64
nb: 23
pixel_criterion:
name: L1Loss
dataset:
train:
name: SRDataset
gt_folder: data/DIV2K/DIV2K_train_HR_sub
lq_folder: data/DIV2K/DIV2K_train_LR_bicubic/X4_sub
num_workers: 4
batch_size: 16
scale: 4
preprocess:
- name: LoadImageFromFile
key: lq
- name: LoadImageFromFile
key: gt
- name: Transforms
input_keys: [lq, gt]
pipeline:
- name: SRPairedRandomCrop
gt_patch_size: 128
scale: 4
keys: [image, image]
- name: PairedRandomHorizontalFlip
keys: [image, image]
- name: PairedRandomVerticalFlip
keys: [image, image]
- name: PairedRandomTransposeHW
keys: [image, image]
- name: Transpose
keys: [image, image]
- name: Normalize
mean: [0., 0., 0.]
std: [255., 255., 255.]
keys: [image, image]
test:
name: SRDataset
gt_folder: data/DIV2K/val_set14/Set14
lq_folder: data/DIV2K/val_set14/Set14_bicLRx4
scale: 4
preprocess:
- name: LoadImageFromFile
key: lq
- name: LoadImageFromFile
key: gt
- name: Transforms
input_keys: [lq, gt]
pipeline:
- name: Transpose
keys: [image, image]
- name: Normalize
mean: [0., 0., 0.]
std: [255., 255., 255.]
keys: [image, image]
lr_scheduler:
name: CosineAnnealingRestartLR
learning_rate: 0.0002
periods: [250000, 250000, 250000, 250000]
restart_weights: [1, 1, 1, 1]
eta_min: !!float 1e-7
optimizer:
name: Adam
# add parameters of net_name to optim
# name should in self.nets
net_names:
- generator
beta1: 0.9
beta2: 0.99
validate:
interval: 5000
save_img: false
metrics:
psnr: # metric name, can be arbitrary
name: PSNR
crop_border: 4
test_y_channel: false
ssim:
name: SSIM
crop_border: 4
test_y_channel: false
log_config:
interval: 10
visiual_interval: 500
snapshot_config:
interval: 5000
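The min_max entry above sets the tensor range that tensor2img assumes when saving validation images. A hedged sketch of such a conversion, assuming a paddle tensor in CHW layout; the real implementation lives in PaddleGAN's utils and may differ in detail:

import numpy as np

def tensor2img(tensor, min_max=(0., 1.)):
    # clamp a CHW float tensor to min_max, then rescale to [0, 255] uint8 HWC
    img = tensor.numpy().astype('float32')
    img = np.clip(img, min_max[0], min_max[1])
    img = (img - min_max[0]) / (min_max[1] - min_max[0])  # map to [0, 1]
    return (img * 255.0).round().astype('uint8').transpose(1, 2, 0)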
total_iters: 250000
output_dir: output_dir
# tensor range for function tensor2img
min_max:
(0., 1.)
model:
name: ESRGAN
generator:
name: RRDBNet
in_nc: 3
out_nc: 3
nf: 64
nb: 23
discriminator:
name: VGGDiscriminator128
in_channels: 3
num_feat: 64
pixel_criterion:
name: L1Loss
loss_weight: !!float 1e-2
perceptual_criterion:
name: PerceptualLoss
layer_weights:
'34': 1.0
perceptual_weight: 1.0
style_weight: 0.0
norm_img: False
gan_criterion:
name: GANLoss
gan_mode: vanilla
loss_weight: !!float 5e-3
dataset:
train:
name: SRDataset
gt_folder: data/DIV2K/DIV2K_train_HR_sub
lq_folder: data/DIV2K/DIV2K_train_LR_bicubic/X4_sub
num_workers: 6
batch_size: 32
scale: 4
preprocess:
- name: LoadImageFromFile
key: lq
- name: LoadImageFromFile
key: gt
- name: Transforms
input_keys: [lq, gt]
pipeline:
- name: SRPairedRandomCrop
gt_patch_size: 128
scale: 4
keys: [image, image]
- name: PairedRandomHorizontalFlip
keys: [image, image]
- name: PairedRandomVerticalFlip
keys: [image, image]
- name: PairedRandomTransposeHW
keys: [image, image]
- name: Transpose
keys: [image, image]
- name: Normalize
mean: [0., 0., 0.]
std: [255., 255., 255.]
keys: [image, image]
test:
name: SRDataset
gt_folder: data/DIV2K/val_set14/Set14
lq_folder: data/DIV2K/val_set14/Set14_bicLRx4
scale: 4
preprocess:
- name: LoadImageFromFile
key: lq
- name: LoadImageFromFile
key: gt
- name: Transforms
input_keys: [lq, gt]
pipeline:
- name: Transpose
keys: [image, image]
- name: Normalize
mean: [0., 0., 0.]
std: [255., 255., 255.]
keys: [image, image]
lr_scheduler:
name: MultiStepDecay
learning_rate: 0.0001
milestones: [50000, 100000, 200000, 300000]
gamma: 0.5
optimizer:
optimG:
name: Adam
net_names:
- generator
weight_decay: 0.0
beta1: 0.9
beta2: 0.99
optimD:
name: Adam
net_names:
- discriminator
weight_decay: 0.0
beta1: 0.9
beta2: 0.99
validate:
interval: 5000
save_img: false
metrics:
psnr: # metric name, can be arbitrary
name: PSNR
crop_border: 4
test_y_channel: false
ssim:
name: SSIM
crop_border: 4
test_y_channel: false
log_config:
interval: 100
visiual_interval: 500
snapshot_config:
interval: 5000
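The ESRGAN config above combines three weighted criteria in the generator objective: pixel-wise L1 at 1e-2, perceptual loss on VGG feature layer '34' at 1.0, and the adversarial term at 5e-3. An illustrative composition with assumed names, not the model's actual method; each configured criterion already carries its loss_weight:

def esrgan_generator_loss(pixel_criterion, perceptual_criterion, gan_criterion,
                          fake_hr, real_hr, fake_pred):
    l_pix = pixel_criterion(fake_hr, real_hr)          # weighted L1 (1e-2)
    l_percep = perceptual_criterion(fake_hr, real_hr)  # VGG feature distance
    l_gan = gan_criterion(fake_pred, True)             # adversarial (5e-3)
    return l_pix + l_percep + l_gan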
epochs: 100
output_dir: tmp
checkpoints_dir: checkpoints
lambda_A: 10.0
lambda_B: 10.0
lambda_identity: 0.5
model:
name: MakeupModel
@@ -17,7 +14,18 @@ model:
n_layers: 3
input_nc: 3
norm_type: spectral
gan_mode: lsgan
cycle_criterion:
name: L1Loss
idt_criterion:
name: L1Loss
loss_weight: 0.5
l1_criterion:
name: L1Loss
l2_criterion:
name: MSELoss
gan_criterion:
name: GANLoss
gan_mode: lsgan
dataset:
train:
@@ -26,28 +34,42 @@ dataset:
dataroot: data/MT-Dataset
cls_list: [non-makeup, makeup]
phase: train
pool_size: 16
test:
name: MakeupDataset
trans_size: 256
dataroot: data/MT-Dataset
cls_list: [non-makeup, makeup]
phase: test
pool_size: 16
optimizer:
name: Adam
beta1: 0.5
lr_scheduler:
name: linear
name: LinearDecay
learning_rate: 0.0002
start_epoch: 100
decay_epochs: 100
# will get from real dataset
iters_per_epoch: 1
optimizer:
optimizer_G:
name: Adam
net_names:
- netG
beta1: 0.5
optimizer_DA:
name: Adam
net_names:
- netD_A
beta1: 0.5
optimizer_DB:
name: Adam
net_names:
- netD_B
beta1: 0.5
log_config:
interval: 10
visiual_interval: 500
snapshot_config:
interval: 1
interval: 5
epochs: 200
output_dir: output_dir
lambda_L1: 100
model:
name: Pix2PixModel
@@ -18,70 +17,89 @@ model:
n_layers: 3
input_nc: 6
norm_type: batch
gan_mode: vanilla
direction: b2a
pixel_criterion:
name: L1Loss
loss_weight: 100
gan_criterion:
name: GANLoss
gan_mode: vanilla
dataset:
train:
name: PairedDataset
dataroot: data/cityscapes
dataroot: data/cityscapes/train
num_workers: 4
batch_size: 1
phase: train
max_dataset_size: inf
direction: BtoA
input_nc: 3
output_nc: 3
serial_batches: False
pool_size: 0
transforms:
- name: Resize
size: [286, 286]
interpolation: 'bicubic' #cv2.INTER_CUBIC
keys: [image, image]
- name: PairedRandomCrop
size: [256, 256]
keys: [image, image]
- name: PairedRandomHorizontalFlip
prob: 0.5
keys: [image, image]
- name: Transpose
keys: [image, image]
- name: Normalize
mean: [127.5, 127.5, 127.5]
std: [127.5, 127.5, 127.5]
keys: [image, image]
preprocess:
- name: LoadImageFromFile
key: pair
- name: SplitPairedImage
key: pair
paired_keys: [A, B]
- name: Transforms
input_keys: [A, B]
pipeline:
- name: Resize
size: [286, 286]
interpolation: 'bicubic' #cv2.INTER_CUBIC
keys: [image, image]
- name: PairedRandomCrop
size: [256, 256]
keys: [image, image]
- name: PairedRandomHorizontalFlip
prob: 0.5
keys: [image, image]
- name: Transpose
keys: [image, image]
- name: Normalize
mean: [127.5, 127.5, 127.5]
std: [127.5, 127.5, 127.5]
keys: [image, image]
test:
name: PairedDataset
dataroot: data/cityscapes/
phase: test
max_dataset_size: inf
direction: BtoA
input_nc: 3
output_nc: 3
serial_batches: True
pool_size: 50
transforms:
- name: Resize
size: [256, 256]
interpolation: 'bicubic' #cv2.INTER_CUBIC
keys: [image, image]
- name: Transpose
keys: [image, image]
- name: Normalize
mean: [127.5, 127.5, 127.5]
std: [127.5, 127.5, 127.5]
keys: [image, image]
optimizer:
name: Adam
beta1: 0.5
dataroot: data/cityscapes/test
num_workers: 4
batch_size: 1
load_pipeline:
- name: LoadImageFromFile
key: pair
- name: SplitPairedImage
key: pair
paired_keys: [A, B]
- name: Transforms
input_keys: [A, B]
pipeline:
- name: Resize
size: [256, 256]
interpolation: 'bicubic' #cv2.INTER_CUBIC
keys: [image, image]
- name: Transpose
keys: [image, image]
- name: Normalize
mean: [127.5, 127.5, 127.5]
std: [127.5, 127.5, 127.5]
keys: [image, image]
lr_scheduler:
name: linear
name: LinearDecay
learning_rate: 0.0002
start_epoch: 100
decay_epochs: 100
# will get from real dataset
iters_per_epoch: 1
optimizer:
optimG:
name: Adam
net_names:
- netG
beta1: 0.5
optimD:
name: Adam
net_names:
- netD
beta1: 0.5
log_config:
interval: 100
......
epochs: 200
output_dir: output_dir
lambda_L1: 100
model:
name: Pix2PixModel
@@ -18,69 +17,86 @@ model:
n_layers: 3
input_nc: 6
norm_type: batch
gan_mode: vanilla
direction: b2a
pixel_criterion:
name: L1Loss
loss_weight: 100
gan_criterion:
name: GANLoss
gan_mode: vanilla
dataset:
train:
name: PairedDataset
dataroot: data/cityscapes
num_workers: 0
dataroot: data/cityscapes/train
num_workers: 4
batch_size: 1
phase: train
max_dataset_size: inf
direction: BtoA
input_nc: 3
output_nc: 3
serial_batches: False
pool_size: 0
transforms:
- name: Resize
size: [286, 286]
interpolation: 'bicubic' #cv2.INTER_CUBIC
keys: [image, image]
- name: PairedRandomCrop
size: [256, 256]
keys: [image, image]
- name: PairedRandomHorizontalFlip
prob: 0.5
keys: [image, image]
- name: Transpose
keys: [image, image]
- name: Normalize
mean: [127.5, 127.5, 127.5]
std: [127.5, 127.5, 127.5]
keys: [image, image]
preprocess:
- name: LoadImageFromFile
key: pair
- name: SplitPairedImage
key: pair
paired_keys: [A, B]
- name: Transforms
input_keys: [A, B]
pipeline:
- name: Resize
size: [286, 286]
interpolation: 'bicubic' #cv2.INTER_CUBIC
keys: [image, image]
- name: PairedRandomCrop
size: [256, 256]
keys: [image, image]
- name: PairedRandomHorizontalFlip
prob: 0.5
keys: [image, image]
- name: Transpose
keys: [image, image]
- name: Normalize
mean: [127.5, 127.5, 127.5]
std: [127.5, 127.5, 127.5]
keys: [image, image]
test:
name: PairedDataset
dataroot: data/cityscapes/
phase: test
max_dataset_size: inf
direction: BtoA
input_nc: 3
output_nc: 3
serial_batches: True
pool_size: 50
transforms:
- name: Resize
size: [256, 256]
interpolation: 'bicubic' #cv2.INTER_CUBIC
keys: [image, image]
- name: Transpose
keys: [image, image]
- name: Normalize
mean: [127.5, 127.5, 127.5]
std: [127.5, 127.5, 127.5]
keys: [image, image]
optimizer:
name: Adam
beta1: 0.5
dataroot: data/cityscapes/test
num_workers: 4
batch_size: 1
load_pipeline:
- name: LoadImageFromFile
key: pair
- name: Transforms
input_keys: [A, B]
pipeline:
- name: Resize
size: [256, 256]
interpolation: 'bicubic' #cv2.INTER_CUBIC
keys: [image, image]
- name: Transpose
keys: [image, image]
- name: Normalize
mean: [127.5, 127.5, 127.5]
std: [127.5, 127.5, 127.5]
keys: [image, image]
lr_scheduler:
name: linear
name: LinearDecay
learning_rate: 0.0004
start_epoch: 100
decay_epochs: 100
# will get from real dataset
iters_per_epoch: 1
optimizer:
optimG:
name: Adam
net_names:
- netG
beta1: 0.5
optimD:
name: Adam
net_names:
- netD
beta1: 0.5
log_config:
interval: 100
......
epochs: 200
output_dir: output_dir
lambda_L1: 100
model:
name: Pix2PixModel
@@ -18,69 +17,86 @@ model:
n_layers: 3
input_nc: 6
norm_type: batch
gan_mode: vanilla
direction: b2a
pixel_criterion:
name: L1Loss
loss_weight: 100
gan_criterion:
name: GANLoss
gan_mode: vanilla
dataset:
train:
name: PairedDataset
dataroot: data/facades/
num_workers: 0
dataroot: data/facades/train
num_workers: 4
batch_size: 1
phase: train
max_dataset_size: inf
direction: BtoA
input_nc: 3
output_nc: 3
serial_batches: False
pool_size: 0
transforms:
- name: Resize
size: [286, 286]
interpolation: 'bicubic' #cv2.INTER_CUBIC
keys: [image, image]
- name: PairedRandomCrop
size: [256, 256]
keys: [image, image]
- name: PairedRandomHorizontalFlip
prob: 0.5
keys: [image, image]
- name: Transpose
keys: [image, image]
- name: Normalize
mean: [127.5, 127.5, 127.5]
std: [127.5, 127.5, 127.5]
keys: [image, image]
preprocess:
- name: LoadImageFromFile
key: pair
- name: SplitPairedImage
key: pair
paired_keys: [A, B]
- name: Transforms
input_keys: [A, B]
pipeline:
- name: Resize
size: [286, 286]
interpolation: 'bicubic' #cv2.INTER_CUBIC
keys: [image, image]
- name: PairedRandomCrop
size: [256, 256]
keys: [image, image]
- name: PairedRandomHorizontalFlip
prob: 0.5
keys: [image, image]
- name: Transpose
keys: [image, image]
- name: Normalize
mean: [127.5, 127.5, 127.5]
std: [127.5, 127.5, 127.5]
keys: [image, image]
test:
name: PairedDataset
dataroot: data/facades/
phase: test
max_dataset_size: inf
direction: BtoA
input_nc: 3
output_nc: 3
serial_batches: True
pool_size: 50
transforms:
- name: Resize
size: [256, 256]
interpolation: 'bicubic' #cv2.INTER_CUBIC
keys: [image, image]
- name: Transpose
keys: [image, image]
- name: Normalize
mean: [127.5, 127.5, 127.5]
std: [127.5, 127.5, 127.5]
keys: [image, image]
optimizer:
name: Adam
beta1: 0.5
dataroot: data/facades/test
num_workers: 4
batch_size: 1
load_pipeline:
- name: LoadImageFromFile
key: pair
- name: Transforms
input_keys: [A, B]
pipeline:
- name: Resize
size: [256, 256]
interpolation: 'bicubic' #cv2.INTER_CUBIC
keys: [image, image]
- name: Transpose
keys: [image, image]
- name: Normalize
mean: [127.5, 127.5, 127.5]
std: [127.5, 127.5, 127.5]
keys: [image, image]
lr_scheduler:
name: linear
name: LinearDecay
learning_rate: 0.0002
start_epoch: 100
decay_epochs: 100
# will get from real dataset
iters_per_epoch: 1
optimizer:
optimG:
name: Adam
net_names:
- netG
beta1: 0.5
optimD:
name: Adam
net_names:
- netD
beta1: 0.5
log_config:
interval: 100
......
epochs: 300
output_dir: output_dir
adv_weight: 1.0
cycle_weight: 10.0
identity_weight: 10.0
cam_weight: 1000.0
model:
name: UGATITModel
@@ -25,57 +21,102 @@ model:
input_nc: 3
ndf: 64
n_layers: 5
l1_criterion:
name: L1Loss
mse_criterion:
name: MSELoss
bce_criterion:
name: BCEWithLogitsLoss
adv_weight: 1.0
cycle_weight: 10.0
identity_weight: 10.0
cam_weight: 1000.0
dataset:
train:
name: UnpairedDataset
dataroot: data/selfie2anime
dataroot_a: data/selfie2anime/trainA
dataroot_b: data/selfie2anime/trainB
num_workers: 0
phase: train
max_dataset_size: inf
direction: AtoB
input_nc: 3
output_nc: 3
serial_batches: False
transforms:
- name: Resize
size: [286, 286]
interpolation: 'bilinear' #'bicubic' #cv2.INTER_CUBIC
- name: RandomCrop
size: [256, 256]
- name: RandomHorizontalFlip
prob: 0.5
- name: Transpose
- name: Normalize
mean: [127.5, 127.5, 127.5]
std: [127.5, 127.5, 127.5]
batch_size: 1
is_train: True
max_size: inf
preprocess:
- name: LoadImageFromFile
key: A
- name: LoadImageFromFile
key: B
- name: Transforms
input_keys: [A, B]
pipeline:
- name: Resize
size: [286, 286]
interpolation: 'bicubic' #cv2.INTER_CUBIC
keys: ['image', 'image']
- name: RandomCrop
size: [256, 256]
keys: ['image', 'image']
- name: RandomHorizontalFlip
prob: 0.5
keys: ['image', 'image']
- name: Transpose
keys: ['image', 'image']
- name: Normalize
mean: [127.5, 127.5, 127.5]
std: [127.5, 127.5, 127.5]
keys: ['image', 'image']
test:
name: SingleDataset
dataroot: data/selfie2anime/testA
max_dataset_size: inf
direction: AtoB
input_nc: 3
output_nc: 3
serial_batches: False
transforms:
- name: Resize
size: [256, 256]
interpolation: 'bilinear' #cv2.INTER_CUBIC
- name: Transpose
- name: Normalize
mean: [127.5, 127.5, 127.5]
std: [127.5, 127.5, 127.5]
optimizer:
name: Adam
beta1: 0.5
weight_decay: 0.0001
name: UnpairedDataset
dataroot_a: data/selfie2anime/testA
dataroot_b: data/selfie2anime/testB
num_workers: 0
batch_size: 1
max_size: inf
is_train: False
preprocess:
- name: LoadImageFromFile
key: A
- name: LoadImageFromFile
key: B
- name: Transforms
input_keys: [A, B]
pipeline:
- name: Resize
size: [256, 256]
interpolation: 'bicubic' #cv2.INTER_CUBIC
keys: ['image', 'image']
- name: Transpose
keys: ['image', 'image']
- name: Normalize
mean: [127.5, 127.5, 127.5]
std: [127.5, 127.5, 127.5]
keys: ['image', 'image']
lr_scheduler:
name: linear
name: LinearDecay
learning_rate: 0.0001
start_epoch: 150
decay_epochs: 150
# will get from real dataset
iters_per_epoch: 1
optimizer:
optimG:
name: Adam
net_names:
- genA2B
- genB2A
weight_decay: 0.0001
beta1: 0.5
optimD:
name: Adam
net_names:
- disGA
- disGB
- disLA
- disLB
weight_decay: 0.0001
beta1: 0.5
log_config:
interval: 10
......
@@ -15,7 +15,7 @@
from .unpaired_dataset import UnpairedDataset
from .single_dataset import SingleDataset
from .paired_dataset import PairedDataset
from .sr_image_dataset import SRImageDataset
from .base_sr_dataset import SRDataset
from .makeup_dataset import MakeupDataset
from .common_vision_dataset import CommonVisionDataset
from .animeganv2_dataset import AnimeGANV2Dataset
@@ -12,105 +12,124 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# code was heavily based on https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
import random
import numpy as np
import os
from pathlib import Path
from abc import ABCMeta, abstractmethod
from paddle.io import Dataset
from PIL import Image
import cv2
import paddle.vision.transforms as transforms
from .transforms import transforms as T
from abc import ABC, abstractmethod
from .preprocess import build_preprocess
IMG_EXTENSIONS = ('.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm',
'.PPM', '.bmp', '.BMP')
class BaseDataset(Dataset, ABC):
"""This class is an abstract base class (ABC) for datasets.
def scandir(dir_path, suffix=None, recursive=False):
"""Scan a directory to find the interested files.
Args:
dir_path (str | obj:`Path`): Path of the directory.
suffix (str | tuple(str), optional): File suffix that we are
interested in. Default: None.
recursive (bool, optional): If set to True, recursively scan the
directory. Default: False.
Returns:
A generator for all files of interest, yielded as relative paths.
"""
def __init__(self, cfg):
"""Initialize the class; save the options in the class
if isinstance(dir_path, (str, Path)):
dir_path = str(dir_path)
else:
raise TypeError('"dir_path" must be a string or Path object')
Args:
cfg (dict) -- stores all the experiment flags
"""
self.cfg = cfg
self.root = cfg.dataroot
if (suffix is not None) and not isinstance(suffix, (str, tuple)):
raise TypeError('"suffix" must be a string or tuple of strings')
@abstractmethod
def __len__(self):
"""Return the total number of images in the dataset."""
return 0
root = dir_path
def _scandir(dir_path, suffix, recursive):
for entry in os.scandir(dir_path):
if not entry.name.startswith('.') and entry.is_file():
rel_path = os.path.relpath(entry.path, root)
if suffix is None:
yield rel_path
elif rel_path.endswith(suffix):
yield rel_path
else:
if recursive:
yield from _scandir(entry.path,
suffix=suffix,
recursive=recursive)
else:
continue
return _scandir(dir_path, suffix=suffix, recursive=recursive)
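A usage sketch for scandir (the folder path is a placeholder): it yields paths relative to the given root, skips hidden files, and can recurse and filter by suffix:

pngs = sorted(scandir('data/DIV2K/DIV2K_train_HR_sub',
                      suffix='.png', recursive=True))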
class BaseDataset(Dataset, metaclass=ABCMeta):
"""Base class for datasets.
All datasets should subclass it.
All subclasses should overwrite:
``prepare_data_infos``, supporting to load information and generate
image lists.
Args:
preprocess (list[dict]): A sequence of data preprocess config.
"""
def __init__(self, preprocess=None):
super(BaseDataset, self).__init__()
if preprocess:
self.preprocess = build_preprocess(preprocess)
@abstractmethod
def __getitem__(self, index):
"""Return a data point and its metadata information.
def prepare_data_infos(self):
"""Abstract function for loading annotation.
All subclasses should overwrite this function and
set self.data_infos in it. data_infos should be a list of dicts:
[{key_path: file_path}, {key_path: file_path}, {key_path: file_path}]
"""
self.data_infos = None
@staticmethod
def scan_folder(path):
"""Obtain sample path list (including sub-folders) from a given folder.
Parameters:
index - - a random integer for data indexing
Args:
path (str|pathlib.Path): Folder path.
Returns:
a dictionary of data with their names. It usually contains the data itself and its metadata information.
list[str]: sample list obtained from the given folder.
"""
pass
def get_params(cfg, size):
w, h = size
new_h = h
new_w = w
if cfg.preprocess == 'resize_and_crop':
new_h = new_w = cfg.load_size
elif cfg.preprocess == 'scale_width_and_crop':
new_w = cfg.load_size
new_h = cfg.load_size * h // w
x = random.randint(0, np.maximum(0, new_w - cfg.crop_size))
y = random.randint(0, np.maximum(0, new_h - cfg.crop_size))
flip = random.random() > 0.5
return {'crop_pos': (x, y), 'flip': flip}
def get_transform(cfg,
params=None,
grayscale=False,
method=cv2.INTER_CUBIC,
convert=True):
transform_list = []
if grayscale:
print('grayscale not support for now!!!')
pass
if 'resize' in cfg.preprocess:
osize = (cfg.load_size, cfg.load_size)
transform_list.append(transforms.Resize(osize, method))
elif 'scale_width' in cfg.preprocess:
print('scale_width not support for now!!!')
pass
if 'crop' in cfg.preprocess:
if params is None:
transform_list.append(T.RandomCrop(cfg.crop_size))
if isinstance(path, (str, Path)):
path = str(path)
else:
transform_list.append(T.Crop(params['crop_pos'], cfg.crop_size))
if cfg.preprocess == 'none':
print('preprocess not support for now!!!')
pass
if not cfg.no_flip:
if params is None:
transform_list.append(transforms.RandomHorizontalFlip())
elif params['flip']:
transform_list.append(transforms.RandomHorizontalFlip(1.0))
if convert:
transform_list += [transforms.Permute(to_rgb=True)]
if cfg.get('normalize', None):
transform_list += [
transforms.Normalize(cfg.normalize.mean, cfg.normalize.std)
]
return transforms.Compose(transform_list)
raise TypeError("'path' must be a str or a Path object, "
f'but received {type(path)}.')
samples = list(scandir(path, suffix=IMG_EXTENSIONS, recursive=True))
samples = [os.path.join(path, v) for v in samples]
assert samples, '{} has no valid image file.'.format(path)
return samples
def __getitem__(self, idx):
datas = self.data_infos[idx]
if hasattr(self, 'preprocess') and self.preprocess:
datas = self.preprocess(datas)
return datas
def __len__(self):
"""Length of the dataset.
Returns:
int: Length of the dataset.
"""
return len(self.data_infos)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import copy
from pathlib import Path
from .base_dataset import BaseDataset
from .builder import DATASETS
@DATASETS.register()
class SRDataset(BaseDataset):
"""Base super resulotion dataset for image restoration."""
def __init__(self,
lq_folder,
gt_folder,
preprocess,
scale,
filename_tmpl='{}'):
super(SRDataset, self).__init__(preprocess)
self.lq_folder = lq_folder
self.gt_folder = gt_folder
self.scale = scale
self.filename_tmpl = filename_tmpl
self.prepare_data_infos()
def prepare_data_infos(self):
"""Load annoations for SR dataset.
It loads the LQ and GT image path from folders.
Returns:
dict: Returned dict for LQ and GT pairs.
"""
self.data_infos = []
lq_paths = self.scan_folder(self.lq_folder)
gt_paths = self.scan_folder(self.gt_folder)
assert len(lq_paths) == len(gt_paths), (
f'gt and lq datasets have different number of images: '
f'{len(lq_paths)}, {len(gt_paths)}.')
for gt_path in gt_paths:
basename, ext = os.path.splitext(os.path.basename(gt_path))
lq_path = os.path.join(self.lq_folder,
(f'{self.filename_tmpl.format(basename)}'
f'{ext}'))
assert lq_path in lq_paths, f'{lq_path} is not in lq_paths.'
self.data_infos.append(dict(lq_path=lq_path, gt_path=gt_path))
return self.data_infos
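A hedged construction example matching the training config above; the preprocess list is abbreviated to the two load steps:

dataset = SRDataset(
    lq_folder='data/DIV2K/DIV2K_train_LR_bicubic/X4_sub',
    gt_folder='data/DIV2K/DIV2K_train_HR_sub',
    preprocess=[dict(name='LoadImageFromFile', key='lq'),
                dict(name='LoadImageFromFile', key='gt')],
    scale=4)
sample = dataset[0]  # dict with decoded 'lq' and 'gt' arrays plus their paths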
@@ -66,22 +66,37 @@ class DictDataset(paddle.io.Dataset):
class DictDataLoader():
def __init__(self, dataset, batch_size, is_train, num_workers=4):
def __init__(self,
dataset,
batch_size,
is_train,
num_workers=4,
distributed=True):
self.dataset = DictDataset(dataset)
place = paddle.CUDAPlace(ParallelEnv().dev_id) \
if ParallelEnv().nranks > 1 else paddle.CUDAPlace(0)
sampler = DistributedBatchSampler(self.dataset,
batch_size=batch_size,
shuffle=True if is_train else False,
drop_last=True if is_train else False)
self.dataloader = paddle.io.DataLoader(self.dataset,
batch_sampler=sampler,
places=place,
num_workers=num_workers)
if distributed:
sampler = DistributedBatchSampler(
self.dataset,
batch_size=batch_size,
shuffle=True if is_train else False,
drop_last=True if is_train else False)
self.dataloader = paddle.io.DataLoader(self.dataset,
batch_sampler=sampler,
places=place,
num_workers=num_workers)
else:
self.dataloader = paddle.io.DataLoader(
self.dataset,
batch_size=batch_size,
shuffle=True if is_train else False,
drop_last=True if is_train else False,
places=place,
num_workers=num_workers)
self.batch_size = batch_size
@@ -117,12 +132,20 @@ class DictDataLoader():
return current_items
def build_dataloader(cfg, is_train=True):
dataset = DATASETS.get(cfg.name)(cfg)
def build_dataloader(cfg, is_train=True, distributed=True):
cfg_ = cfg.copy()
batch_size = cfg_.pop('batch_size', 1)
num_workers = cfg_.pop('num_workers', 0)
name = cfg_.pop('name')
batch_size = cfg.get('batch_size', 1)
num_workers = cfg.get('num_workers', 0)
dataset = DATASETS.get(name)(**cfg_)
dataloader = DictDataLoader(dataset, batch_size, is_train, num_workers)
dataloader = DictDataLoader(dataset,
batch_size,
is_train,
num_workers,
distributed=distributed)
return dataloader
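A usage sketch, assuming cfg.dataset.train is the dict-like train section from one of the YAML configs above; build_dataloader pops name, batch_size, and num_workers and forwards the remaining keys to the dataset constructor:

train_loader = build_dataloader(cfg.dataset.train, is_train=True,
                                distributed=False)  # plain DataLoader path
for batch in train_loader:
    pass  # each batch is a dict of tensors keyed as the dataset emits them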
@@ -13,35 +13,38 @@
# limitations under the License.
import cv2
import os.path
from .base_dataset import BaseDataset, get_transform
from .transforms.makeup_transforms import get_makeup_transform
import paddle.vision.transforms as T
from PIL import Image
import random
import os.path
import numpy as np
from PIL import Image
import paddle
import paddle.vision.transforms as T
from .base_dataset import BaseDataset
from ..utils.preprocess import *
from .builder import DATASETS
@DATASETS.register()
class MakeupDataset(BaseDataset):
def __init__(self, cfg):
"""Initialize this dataset class.
class MakeupDataset(paddle.io.Dataset):
def __init__(self, dataroot, phase, trans_size, cls_list):
"""Initialize psgan dataset class.
Parameters:
opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
Args:
dataroot (str): Directory of dataset.
phase (str): 'train' or 'test'.
trans_size (int): Target size used by the resize transforms.
cls_list (list[str]): The two class names, e.g. [non-makeup, makeup].
"""
BaseDataset.__init__(self, cfg)
self.image_path = cfg.dataroot
self.mode = cfg.phase
self.transform = get_makeup_transform(cfg)
self.image_path = dataroot
self.mode = phase
self.trans_size = trans_size
self.cls_list = cls_list
self.transform = self.build_makeup_transform()
self.norm = T.Normalize([127.5, 127.5, 127.5], [127.5, 127.5, 127.5])
self.transform_mask = get_makeup_transform(cfg, pic="mask")
self.trans_size = cfg.trans_size
self.cls_list = cfg.cls_list
self.transform_mask = self.build_makeup_transform("mask")
self.trans_size = trans_size
self.cls_A = self.cls_list[0]
self.cls_B = self.cls_list[1]
for cls in self.cls_list:
......@@ -72,6 +75,18 @@ class MakeupDataset(BaseDataset):
getattr(self, cls + "_mask_filenames").append(splits[1])
getattr(self, cls + "_lmks_filenames").append(splits[2])
def build_makeup_transform(self, pic="image"):
if pic == "image":
transform = T.Compose([
T.Resize(size=self.trans_size),
T.Transpose(),
])
else:
transform = T.Resize(size=self.trans_size,
interpolation=cv2.INTER_NEAREST)
return transform
def __getitem__(self, index):
"""Return MANet and MDNet needed params.
......
@@ -12,65 +12,35 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import cv2
import paddle
import os.path
from .base_dataset import BaseDataset, get_params, get_transform
from .image_folder import make_dataset
from .builder import DATASETS
from .transforms.builder import build_transforms
from .base_dataset import BaseDataset
@DATASETS.register()
class PairedDataset(BaseDataset):
"""A dataset class for paired image dataset.
"""
def __init__(self, cfg):
def __init__(self, dataroot, preprocess):
"""Initialize this dataset class.
Args:
cfg (dict): configs of datasets.
"""
BaseDataset.__init__(self, cfg)
self.dir_AB = os.path.join(cfg.dataroot,
cfg.phase) # get the image directory
self.AB_paths = sorted(make_dataset(
self.dir_AB, cfg.max_dataset_size)) # get image paths
self.input_nc = self.cfg.output_nc if self.cfg.direction == 'BtoA' else self.cfg.input_nc
self.output_nc = self.cfg.input_nc if self.cfg.direction == 'BtoA' else self.cfg.output_nc
self.transforms = build_transforms(cfg.transforms)
def __getitem__(self, index):
"""Return a data point and its metadata information.
dataroot (str): Directory of dataset.
preprocess (list[dict]): A sequence of data preprocess config.
Parameters:
index - - a random integer for data indexing
Returns a dictionary that contains A, B, A_paths and B_paths
A (tensor) - - an image in the input domain
B (tensor) - - its corresponding image in the target domain
A_paths (str) - - image paths
B_paths (str) - - image paths (same as A_paths)
"""
# read a image given a random integer index
AB_path = self.AB_paths[index]
AB = cv2.cvtColor(cv2.imread(AB_path), cv2.COLOR_BGR2RGB)
# split AB image into A and B
h, w = AB.shape[:2]
# w, h = AB.size
w2 = int(w / 2)
super(PairedDataset, self).__init__(preprocess)
self.dataroot = dataroot
self.data_infos = self.prepare_data_infos()
A = AB[:h, :w2, :]
B = AB[:h, w2:, :]
def prepare_data_infos(self):
"""Load paired image paths.
# apply the same transform to both A and B
A, B = self.transforms((A, B))
return {'A': A, 'B': B, 'A_paths': AB_path, 'B_paths': AB_path}
Returns:
list[dict]: List that contains paired image paths.
"""
data_infos = []
pair_paths = sorted(self.scan_folder(self.dataroot))
for pair_path in pair_paths:
data_infos.append(dict(pair_path=pair_path))
def __len__(self):
"""Return the total number of images in the dataset."""
return len(self.AB_paths)
return data_infos
from .io import LoadImageFromFile
from .transforms import (PairedRandomCrop, PairedRandomHorizontalFlip,
PairedRandomVerticalFlip, PairedRandomTransposeHW,
SRPairedRandomCrop, SplitPairedImage)
from .builder import build_preprocess
@@ -12,52 +12,53 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict
import paddle
import paddle.nn as nn
import copy
import traceback
from .generators.builder import build_generator
from .base_model import BaseModel
from .losses import GANLoss
from .builder import MODELS
from ...utils.registry import Registry, build_from_config
LOAD_PIPELINE = Registry("LOAD_PIPELINE")
TRANSFORMS = Registry("TRANSFORM")
PREPROCESS = Registry("PREPROCESS")
@MODELS.register()
class SRGANModel(BaseModel):
def __init__(self, cfg):
super(SRGANModel, self).__init__(cfg)
# define networks
self.model_names = ['G']
class Compose(object):
"""
Composes several transforms together; used for composing a list of
transforms into a single dataset transform.
self.netG = build_generator(cfg.model.generator)
self.visual_names = ['LQ', 'GT', 'fake_H']
Args:
functions (list[callable]): List of functions to compose.
# TODO: support srgan train.
if False:
# self.netD = build_discriminator(cfg.model.discriminator)
self.netG.train()
# self.netD.train()
Returns:
A Compose object which is callable; __call__ for this Compose
object will call each given :attr:`functions` sequentially.
def set_input(self, input):
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
"""
def __init__(self, functions):
self.functions = functions
Parameters:
input (dict): include the data itself and its metadata information.
def __call__(self, datas):
The option 'direction' can be used to swap images in domain A and domain B.
"""
for func in self.functions:
try:
datas = func(datas)
except Exception as e:
stack_info = traceback.format_exc()
print("fail to perform fuction [{}] with error: "
"{} and stack:\n{}".format(func, e, str(stack_info)))
raise RuntimeError
return datas
# AtoB = self.opt.dataset.train.direction == 'AtoB'
if 'A' in input:
self.LQ = paddle.to_tensor(input['A'])
if 'B' in input:
self.GT = paddle.to_tensor(input['B'])
if 'A_paths' in input:
self.image_paths = input['A_paths']
def forward(self):
self.fake_H = self.netG(self.LQ)
def build_preprocess(cfg):
preprocess = []
if not isinstance(cfg, (list, tuple)):
cfg = [cfg]
def optimize_parameters(self, step):
pass
for cfg_ in cfg:
process = build_from_config(cfg_, PREPROCESS)
preprocess.append(process)
preprocess = Compose(preprocess)
return preprocess
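A usage sketch for build_preprocess, mirroring the preprocess lists in the configs above; it assumes build_from_config looks up each dict's name in the registry and passes the remaining keys to the constructor:

preprocess = build_preprocess([
    dict(name='LoadImageFromFile', key='gt'),
    dict(name='Transforms', input_keys=['gt'],
         pipeline=[dict(name='Transpose', keys=['image'])]),
])
datas = preprocess(dict(gt_path='data/example.png'))  # placeholder path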
import cv2
from .builder import PREPROCESS
@PREPROCESS.register()
class LoadImageFromFile(object):
"""Load image from file.
Args:
key (str): Key in results used to find the corresponding path. Default: 'image'.
flag (int): Loading flag passed to cv2.imread. Default: -1.
to_rgb (bool): Whether to convert the image to RGB. Default: True.
backend (str): IO backend where images are stored. Default: None.
save_original_img (bool): If True, maintain a copy of the image in
`results` dict with name of `f'ori_{key}'`. Default: False.
kwargs (dict): Args for file client.
"""
def __init__(self,
key='image',
flag=-1,
to_rgb=True,
save_original_img=False,
backend=None,
**kwargs):
self.key = key
self.flag = flag
self.to_rgb = to_rgb
self.backend = backend
self.save_original_img = save_original_img
self.kwargs = kwargs
def __call__(self, results):
"""Call function.
Args:
results (dict): A dict containing the necessary information and
data for augmentation.
Returns:
dict: A dict containing the processed data and information.
"""
filepath = str(results[f'{self.key}_path'])
#TODO: use file client to manage io backend
# such as opencv, pil, imdb
img = cv2.imread(filepath, self.flag)
if self.to_rgb:
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
results[self.key] = img
results[f'{self.key}_path'] = filepath
results[f'{self.key}_ori_shape'] = img.shape
if self.save_original_img:
results[f'ori_{self.key}'] = img.copy()
return results
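A call sketch with a placeholder path: the loader reads results['gt_path'], decodes it with OpenCV, converts BGR to RGB, and stores the array back under the key itself:

loader = LoadImageFromFile(key='gt')
results = loader({'gt_path': 'data/example.png'})
# results['gt'] is an H x W x C uint8 array; results['gt_ori_shape'] its shape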
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import random
import numbers
import collections
import paddle.vision.transforms as T
import paddle.vision.transforms.functional as F
from .builder import TRANSFORMS, build_from_config
from .builder import PREPROCESS
if sys.version_info < (3, 3):
Sequence = collections.Sequence
Iterable = collections.Iterable
else:
Sequence = collections.abc.Sequence
Iterable = collections.abc.Iterable
TRANSFORMS.register(T.Resize)
TRANSFORMS.register(T.RandomCrop)
TRANSFORMS.register(T.RandomHorizontalFlip)
TRANSFORMS.register(T.RandomVerticalFlip)
TRANSFORMS.register(T.Normalize)
TRANSFORMS.register(T.Transpose)
@PREPROCESS.register()
class Transforms():
def __init__(self, pipeline, input_keys):
self.input_keys = input_keys
self.transforms = []
for transform_cfg in pipeline:
self.transforms.append(build_from_config(transform_cfg, TRANSFORMS))
def __call__(self, datas):
data = []
for k in self.input_keys:
data.append(datas[k])
data = tuple(data)
for transform in self.transforms:
data = transform(data)
if hasattr(transform, 'params') and isinstance(
transform.params, dict):
datas.update(transform.params)
for i, k in enumerate(self.input_keys):
datas[k] = data[i]
return datas
@PREPROCESS.register()
class SplitPairedImage:
def __init__(self, key, paired_keys=['A', 'B']):
self.key = key
self.paired_keys = paired_keys
def __call__(self, datas):
# split AB image into A and B
h, w = datas[self.key].shape[:2]
# w, h = AB.size
w2 = int(w / 2)
a, b = self.paired_keys
datas[a] = datas[self.key][:h, :w2, :]
datas[b] = datas[self.key][:h, w2:, :]
datas[a + '_path'] = datas[self.key + '_path']
datas[b + '_path'] = datas[self.key + '_path']
return datas
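A worked sketch: pix2pix stores the two domains side by side in one image, so a 256 x 512 'pair' splits into two 256 x 256 halves under keys A and B:

import numpy as np

split = SplitPairedImage(key='pair', paired_keys=['A', 'B'])
datas = {'pair': np.zeros((256, 512, 3), dtype='uint8'),
         'pair_path': 'example.png'}  # placeholder
datas = split(datas)
assert datas['A'].shape == (256, 256, 3) and datas['B'].shape == (256, 256, 3)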
@TRANSFORMS.register()
class PairedRandomCrop(T.RandomCrop):
def __init__(self, size, keys=None):
super().__init__(size, keys=keys)
if isinstance(size, int):
self.size = (size, size)
else:
self.size = size
def _get_params(self, inputs):
image = inputs[self.keys.index('image')]
params = {}
params['crop_params'] = self._get_param(image, self.size)
return params
def _apply_image(self, img):
i, j, h, w = self.params['crop_params']
return F.crop(img, i, j, h, w)
@TRANSFORMS.register()
class PairedRandomHorizontalFlip(T.RandomHorizontalFlip):
def __init__(self, prob=0.5, keys=None):
super().__init__(prob, keys=keys)
def _get_params(self, inputs):
params = {}
params['flip'] = random.random() < self.prob
return params
def _apply_image(self, image):
if self.params['flip']:
return F.hflip(image)
return image
@TRANSFORMS.register()
class PairedRandomVerticalFlip(T.RandomHorizontalFlip):
def __init__(self, prob=0.5, keys=None):
super().__init__(prob, keys=keys)
def _get_params(self, inputs):
params = {}
params['flip'] = random.random() < self.prob
return params
def _apply_image(self, image):
if self.params['flip']:
return F.vflip(image)
return image
@TRANSFORMS.register()
class PairedRandomTransposeHW(T.BaseTransform):
"""Randomly transpose images in H and W dimensions with a probability.
(TransposeHW = horizontal flip + anti-clockwise rotation by 90 degrees)
When used with horizontal/vertical flips, it serves as a way of rotation
augmentation.
It also supports randomly transposing a list of images.
Required keys are those listed in the "keys" attribute; added or
modified keys are "transpose" and those listed in "keys".
Args:
prob (float): The probability to transpose the images.
keys (list[str]): The images to be transposed.
"""
def __init__(self, prob=0.5, keys=None):
self.keys = keys
self.prob = prob
def _get_params(self, inputs):
params = {}
params['transpose'] = random.random() < self.prob
return params
def _apply_image(self, image):
if self.params['transpose']:
image = image.transpose(1, 0, 2)
return image
@TRANSFORMS.register()
class SRPairedRandomCrop(T.BaseTransform):
"""Super resolution random crop.
It crops a pair of lq and gt images with corresponding locations.
It also supports accepting lq list and gt list.
Required keys are "scale", "lq", and "gt",
added or modified keys are "lq" and "gt".
Args:
scale (int): model upscale factor.
gt_patch_size (int): cropped gt patch size.
"""
def __init__(self, scale, gt_patch_size, keys=None):
self.gt_patch_size = gt_patch_size
self.scale = scale
self.keys = keys
def __call__(self, inputs):
"""inputs must be (lq_img, gt_img)"""
scale = self.scale
lq_patch_size = self.gt_patch_size // scale
lq = inputs[0]
gt = inputs[1]
h_lq, w_lq, _ = lq.shape
h_gt, w_gt, _ = gt.shape
if h_gt != h_lq * scale or w_gt != w_lq * scale:
    raise ValueError(
        f'Scale mismatch: GT ({h_gt}, {w_gt}) is not {scale}x LQ ({h_lq}, {w_lq}).')
if h_lq < lq_patch_size or w_lq < lq_patch_size:
    raise ValueError(
        f'LQ ({h_lq}, {w_lq}) is smaller than the patch size ({lq_patch_size}, {lq_patch_size}).')
# randomly choose top and left coordinates for lq patch
top = random.randint(0, h_lq - lq_patch_size)
left = random.randint(0, w_lq - lq_patch_size)
# crop lq patch
lq = lq[top:top + lq_patch_size, left:left + lq_patch_size, ...]
# crop corresponding gt patch
top_gt, left_gt = int(top * scale), int(left * scale)
gt = gt[top_gt:top_gt + self.gt_patch_size,
left_gt:left_gt + self.gt_patch_size, ...]
outputs = (lq, gt)
return outputs
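A numeric sketch of the crop geometry: with scale=4 and gt_patch_size=128 as configured above, the LQ patch is 128 // 4 = 32 px, and the GT crop starts at 4x the LQ offsets, so both patches cover the same content:

import numpy as np

crop = SRPairedRandomCrop(scale=4, gt_patch_size=128)
lq = np.zeros((64, 64, 3), dtype='float32')    # stand-in 4x-downscaled input
gt = np.zeros((256, 256, 3), dtype='float32')  # matching ground truth
lq_patch, gt_patch = crop((lq, gt))  # shapes (32, 32, 3) and (128, 128, 3)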
@@ -12,54 +12,34 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import cv2
import paddle
from .base_dataset import BaseDataset, get_transform
from .image_folder import make_dataset
from .base_dataset import BaseDataset
from .builder import DATASETS
from .transforms.builder import build_transforms
@DATASETS.register()
class SingleDataset(BaseDataset):
"""
"""
def __init__(self, cfg):
"""Initialize this dataset class.
def __init__(self, dataroot, preprocess):
"""Initialize single dataset class.
Args:
cfg (dict) -- stores all the experiment flags
dataroot (str): Directory of dataset.
preprocess (list[dict]): A sequence of data preprocess config.
"""
BaseDataset.__init__(self, cfg)
self.A_paths = sorted(make_dataset(cfg.dataroot, cfg.max_dataset_size))
input_nc = self.cfg.output_nc if self.cfg.direction == 'BtoA' else self.cfg.input_nc
self.transform = build_transforms(self.cfg.transforms)
def __getitem__(self, index):
"""Return a data point and its metadata information.
super(SingleDataset, self).__init__(preprocess)
self.dataroot = dataroot
self.data_infos = self.prepare_data_infos()
Parameters:
index - - a random integer for data indexing
def prepare_data_infos(self):
"""prepare image paths from a folder.
Returns a dictionary that contains A and A_paths
A(tensor) - - an image in one domain
A_paths(str) - - the path of the image
Returns:
list[dict]: List that contains image paths.
"""
A_path = self.A_paths[index]
A_img = cv2.cvtColor(cv2.imread(A_path), cv2.COLOR_BGR2RGB)
A = self.transform(A_img)
return {'A': A, 'A_paths': A_path}
def __len__(self):
"""Return the total number of images in the dataset."""
return len(self.A_paths)
data_infos = []
paths = sorted(self.scan_folder(self.dataroot))
for path in paths:
data_infos.append(dict(A_path=path))
def get_path_by_indexs(self, indexs):
if isinstance(indexs, paddle.Tensor):
indexs = indexs.numpy()
current_paths = []
for index in indexs:
current_paths.append(self.A_paths[index])
return current_paths
return data_infos
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import cv2
import random
import numpy as np
import paddle.vision.transforms as transform
from pathlib import Path
from paddle.io import Dataset
from .builder import DATASETS
def scandir(dir_path, suffix=None, recursive=False):
"""Scan a directory to find the interested files.
"""
if isinstance(dir_path, (str, Path)):
dir_path = str(dir_path)
else:
raise TypeError('"dir_path" must be a string or Path object')
if (suffix is not None) and not isinstance(suffix, (str, tuple)):
raise TypeError('"suffix" must be a string or tuple of strings')
root = dir_path
def _scandir(dir_path, suffix, recursive):
for entry in os.scandir(dir_path):
if not entry.name.startswith('.') and entry.is_file():
rel_path = os.path.relpath(entry.path, root)
if suffix is None:
yield rel_path
elif rel_path.endswith(suffix):
yield rel_path
else:
if recursive:
yield from _scandir(entry.path,
suffix=suffix,
recursive=recursive)
else:
continue
return _scandir(dir_path, suffix=suffix, recursive=recursive)
def paired_paths_from_folder(folders, keys, filename_tmpl):
"""Generate paired paths from folders.
"""
assert len(folders) == 2, (
'The len of folders should be 2 with [input_folder, gt_folder]. '
f'But got {len(folders)}')
assert len(keys) == 2, (
'The len of keys should be 2 with [input_key, gt_key]. '
f'But got {len(keys)}')
input_folder, gt_folder = folders
input_key, gt_key = keys
input_paths = list(scandir(input_folder))
gt_paths = list(scandir(gt_folder))
assert len(input_paths) == len(gt_paths), (
f'{input_key} and {gt_key} datasets have different number of images: '
f'{len(input_paths)}, {len(gt_paths)}.')
paths = []
for gt_path in gt_paths:
basename, ext = os.path.splitext(os.path.basename(gt_path))
input_name = f'{filename_tmpl.format(basename)}{ext}'
input_path = os.path.join(input_folder, input_name)
assert input_name in input_paths, (f'{input_name} is not in '
f'{input_key}_paths.')
gt_path = os.path.join(gt_folder, gt_path)
paths.append(
dict([(f'{input_key}_path', input_path),
(f'{gt_key}_path', gt_path)]))
return paths
def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path):
"""Paired random crop.
It crops lists of lq and gt images with corresponding locations.
Args:
img_gts (list[ndarray] | ndarray): GT images. Note that all images
should have the same shape. If the input is an ndarray, it will
be transformed to a list containing itself.
img_lqs (list[ndarray] | ndarray): LQ images. Note that all images
should have the same shape. If the input is an ndarray, it will
be transformed to a list containing itself.
gt_patch_size (int): GT patch size.
scale (int): Scale factor.
gt_path (str): Path to ground-truth.
Returns:
list[ndarray] | ndarray: GT images and LQ images. If returned results
only have one element, just return ndarray.
"""
if not isinstance(img_gts, list):
img_gts = [img_gts]
if not isinstance(img_lqs, list):
img_lqs = [img_lqs]
h_lq, w_lq, _ = img_lqs[0].shape
h_gt, w_gt, _ = img_gts[0].shape
lq_patch_size = gt_patch_size // scale
if h_gt != h_lq * scale or w_gt != w_lq * scale:
raise ValueError(
    f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x '
    f'multiplication of LQ ({h_lq}, {w_lq}).')
if h_lq < lq_patch_size or w_lq < lq_patch_size:
raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
f'({lq_patch_size}, {lq_patch_size}). '
f'Please remove {gt_path}.')
# randomly choose top and left coordinates for lq patch
top = random.randint(0, h_lq - lq_patch_size)
left = random.randint(0, w_lq - lq_patch_size)
# crop lq patch
img_lqs = [
v[top:top + lq_patch_size, left:left + lq_patch_size, ...]
for v in img_lqs
]
# crop corresponding gt patch
top_gt, left_gt = int(top * scale), int(left * scale)
img_gts = [
v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...]
for v in img_gts
]
if len(img_gts) == 1:
img_gts = img_gts[0]
if len(img_lqs) == 1:
img_lqs = img_lqs[0]
return img_gts, img_lqs
def augment(imgs, hflip=True, rotation=True, flows=None):
"""Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees).
"""
hflip = hflip and random.random() < 0.5
vflip = rotation and random.random() < 0.5
rot90 = rotation and random.random() < 0.5
def _augment(img):
if hflip:
cv2.flip(img, 1, img)
if vflip:
cv2.flip(img, 0, img)
if rot90:
img = img.transpose(1, 0, 2)
return img
def _augment_flow(flow):
if hflip:
cv2.flip(flow, 1, flow)
flow[:, :, 0] *= -1
if vflip:
cv2.flip(flow, 0, flow)
flow[:, :, 1] *= -1
if rot90:
flow = flow.transpose(1, 0, 2)
flow = flow[:, :, [1, 0]]
return flow
if not isinstance(imgs, list):
imgs = [imgs]
imgs = [_augment(img) for img in imgs]
if len(imgs) == 1:
imgs = imgs[0]
if flows is not None:
if not isinstance(flows, list):
flows = [flows]
flows = [_augment_flow(flow) for flow in flows]
if len(flows) == 1:
flows = flows[0]
return imgs, flows
else:
return imgs
@DATASETS.register()
class SRImageDataset(Dataset):
"""Paired image dataset for image restoration."""
def __init__(self, cfg):
super(SRImageDataset, self).__init__()
self.cfg = cfg
self.file_client = None
self.io_backend_opt = cfg['io_backend']
self.gt_folder, self.lq_folder = cfg['dataroot_gt'], cfg['dataroot_lq']
if 'filename_tmpl' in cfg:
self.filename_tmpl = cfg['filename_tmpl']
else:
self.filename_tmpl = '{}'
if self.io_backend_opt['type'] == 'lmdb':
#TODO: LielinJiang support lmdb to accelerate io
pass
elif 'meta_info_file' in self.cfg and self.cfg[
'meta_info_file'] is not None:
#TODO: LielinJiang support lmdb to accelerate io
pass
else:
self.paths = paired_paths_from_folder(
[self.lq_folder, self.gt_folder], ['lq', 'gt'],
self.filename_tmpl)
def __getitem__(self, index):
scale = self.cfg['scale']
# Load gt and lq images. Dimension order: HWC; channel order: BGR;
# image range: [0, 1], float32.
gt_path = self.paths[index]['gt_path']
lq_path = self.paths[index]['lq_path']
img_gt = cv2.imread(gt_path).astype(np.float32) / 255.
img_lq = cv2.imread(lq_path).astype(np.float32) / 255.
# augmentation for training
if self.cfg['phase'] == 'train':
gt_size = self.cfg['gt_size']
# random crop
img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale,
gt_path)
# flip, rotation
img_gt, img_lq = augment([img_gt, img_lq], self.cfg['use_flip'],
self.cfg['use_rot'])
# TODO: color space transform
# BGR to RGB, HWC to CHW, numpy to tensor
permute = transform.Permute()
img_gt = permute(img_gt)
img_lq = permute(img_lq)
return {
'lq': img_lq,
'gt': img_gt,
'lq_path': lq_path,
'gt_path': gt_path
}
def __len__(self):
return len(self.paths)
@@ -24,14 +24,11 @@ class Compose(object):
"""
Composes several transforms together; used for composing a list of
transforms into a single dataset transform.
Args:
transforms (list): List of transforms to compose.
Returns:
A Compose object which is callable; __call__ for this Compose
object will call each given :attr:`transforms` sequentially.
"""
def __init__(self, transforms):
self.transforms = transforms
......
@@ -12,80 +12,68 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import cv2
import random
import os.path
from .base_dataset import BaseDataset, get_transform
from .image_folder import make_dataset
from .base_dataset import BaseDataset
from .builder import DATASETS
from .transforms.builder import build_transforms
@DATASETS.register()
class UnpairedDataset(BaseDataset):
"""
"""
def __init__(self, cfg):
"""Initialize this dataset class.
def __init__(self, dataroot_a, dataroot_b, max_size, is_train, preprocess):
"""Initialize unpaired dataset class.
Args:
cfg (dict) -- stores all the experiment flags
"""
BaseDataset.__init__(self, cfg)
self.dir_A = os.path.join(cfg.dataroot, cfg.phase +
'A') # create a path '/path/to/data/trainA'
self.dir_B = os.path.join(cfg.dataroot, cfg.phase +
'B') # create a path '/path/to/data/trainB'
self.A_paths = sorted(make_dataset(
self.dir_A,
cfg.max_dataset_size)) # load images from '/path/to/data/trainA'
self.B_paths = sorted(make_dataset(
self.dir_B,
cfg.max_dataset_size)) # load images from '/path/to/data/trainB'
self.A_size = len(self.A_paths) # get the size of dataset A
self.B_size = len(self.B_paths) # get the size of dataset B
btoA = self.cfg.direction == 'BtoA'
input_nc = self.cfg.output_nc if btoA else self.cfg.input_nc # get the number of channels of input image
output_nc = self.cfg.input_nc if btoA else self.cfg.output_nc # get the number of channels of output image
self.transform_A = build_transforms(self.cfg.transforms)
self.transform_B = build_transforms(self.cfg.transforms)
self.reset_paths()
def reset_paths(self):
self.path_dict = {}
dataroot_a (str): Directory of dataset a.
dataroot_b (str): Directory of dataset b.
max_size (int): Maximum size of the dataset.
is_train (bool): whether in train mode.
preprocess (list[dict]): A sequence of data preprocess config.
def __getitem__(self, index):
"""Return a data point and its metadata information.
Parameters:
index (int) -- a random integer for data indexing
Returns a dictionary that contains A, B, A_paths and B_paths
A (tensor) -- an image in the input domain
B (tensor) -- its corresponding image in the target domain
A_paths (str) -- image paths
B_paths (str) -- image paths
"""
A_path = self.A_paths[
index % self.A_size] # make sure index is within the range
if self.cfg.serial_batches: # make sure index is within the range
index_B = index % self.B_size
else: # randomize the index for domain B to avoid fixed pairs.
index_B = random.randint(0, self.B_size - 1)
B_path = self.B_paths[index_B]
super(UnpairedDataset, self).__init__(preprocess)
self.dir_A = os.path.join(dataroot_a)
self.dir_B = os.path.join(dataroot_b)
self.is_train = is_train
self.data_infos_a = self.prepare_data_infos(self.dir_A)
self.data_infos_b = self.prepare_data_infos(self.dir_B)
self.size_a = len(self.data_infos_a)
self.size_b = len(self.data_infos_b)
def prepare_data_infos(self, dataroot):
"""Load unpaired image paths of one domain.
A_img = cv2.cvtColor(cv2.imread(A_path), cv2.COLOR_BGR2RGB)
B_img = cv2.cvtColor(cv2.imread(B_path), cv2.COLOR_BGR2RGB)
# apply image transformation
A = self.transform_A(A_img)
B = self.transform_B(B_img)
Args:
dataroot (str): Path to the folder root for unpaired images of
one domain.
# return A, B
return {'A': A, 'B': B, 'A_paths': A_path, 'B_paths': B_path}
Returns:
list[dict]: List that contains unpaired image paths of one domain.
"""
data_infos = []
paths = sorted(self.scan_folder(dataroot))
for path in paths:
data_infos.append(dict(path=path))
return data_infos
def __getitem__(self, idx):
if self.is_train:
img_a_path = self.data_infos_a[idx % self.size_a]['path']
idx_b = random.randint(0, self.size_b - 1)
img_b_path = self.data_infos_b[idx_b]['path']
datas = dict(A_path=img_a_path, B_path=img_b_path)
else:
img_a_path = self.data_infos_a[idx % self.size_a]['path']
img_b_path = self.data_infos_b[idx % self.size_b]['path']
datas = dict(A_path=img_a_path, B_path=img_b_path)
if self.preprocess:
datas = self.preprocess(datas)
return datas
def __len__(self):
"""Return the total number of images in the dataset.
......@@ -93,4 +81,4 @@ class UnpairedDataset(BaseDataset):
As we have two datasets with potentially different numbers of images,
we take the maximum of the two.
"""
return max(self.A_size, self.B_size)
return max(self.size_a, self.size_b)
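A direct-instantiation sketch (normally the DATASETS registry builds this from YAML; the paths are placeholders and `preprocess=None` skips the pipeline for illustration):
```
dataset = UnpairedDataset(dataroot_a='data/foo/trainA',   # placeholder
                          dataroot_b='data/foo/trainB',   # placeholder
                          max_size=float('inf'),
                          is_train=True,
                          preprocess=None)                # no preprocessing here
sample = dataset[0]     # {'A_path': ..., 'B_path': ...} since preprocess is None
print(len(dataset))     # max(size_a, size_b)
```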
......@@ -27,25 +27,78 @@ from ..models.builder import build_model
from ..utils.visual import tensor2img, save_image
from ..utils.filesystem import makedirs, save, load
from ..utils.timer import TimeAverager
from ..metric.psnr_ssim import calculate_psnr, calculate_ssim
class Trainer:
def __init__(self, cfg):
class IterLoader:
def __init__(self, dataloader):
self._dataloader = dataloader
self.iter_loader = iter(self._dataloader)
self._epoch = 1
# build train dataloader
self.train_dataloader = build_dataloader(cfg.dataset.train)
@property
def epoch(self):
return self._epoch
def __next__(self):
try:
data = next(self.iter_loader)
except StopIteration:
self._epoch += 1
self.iter_loader = iter(self._dataloader)
data = next(self.iter_loader)
if 'lr_scheduler' in cfg.optimizer:
cfg.optimizer.lr_scheduler.step_per_epoch = len(
self.train_dataloader)
return data
def __len__(self):
return len(self._dataloader)
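`IterLoader` turns an epoch-bound dataloader into an endless iterator so the trainer can count global iterations; a minimal sketch with a plain list standing in for a `paddle.io.DataLoader`:
```
loader = IterLoader([{'i': 0}, {'i': 1}, {'i': 2}])  # stand-in dataloader, length 3
for _ in range(7):
    batch = next(loader)    # never raises StopIteration
print(loader.epoch)         # 3: the underlying iterator was restarted twice
```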
class Trainer:
"""
# trainer calling logic:
#
# build_model || model(BaseModel)
# | ||
# build_dataloader || dataloader
# | ||
# model.setup_lr_schedulers || lr_scheduler
# | ||
# model.setup_optimizers || optimizers
# | ||
# train loop (model.setup_input + model.train_iter) || train loop
# | ||
# print log (model.get_current_losses) ||
# | ||
# save checkpoint (model.nets) \/
"""
def __init__(self, cfg):
# build model
self.model = build_model(cfg)
self.model = build_model(cfg.model)
# multiple gpus prepare
if ParallelEnv().nranks > 1:
self.distributed_data_parallel()
# build train dataloader
self.train_dataloader = build_dataloader(cfg.dataset.train)
self.iters_per_epoch = len(self.train_dataloader)
# build lr scheduler
# TODO: is there a better way?
if 'lr_scheduler' in cfg and 'iters_per_epoch' in cfg.lr_scheduler:
cfg.lr_scheduler.iters_per_epoch = self.iters_per_epoch
self.lr_schedulers = self.model.setup_lr_schedulers(cfg.lr_scheduler)
# build optimizers
self.optimizers = self.model.setup_optimizers(self.lr_schedulers,
cfg.optimizer)
# build metrics
self.metrics = None
validate_cfg = cfg.get('validate', None)
if validate_cfg and 'metrics' in validate_cfg:
self.metrics = self.model.setup_metrics(validate_cfg['metrics'])
self.logger = logging.getLogger(__name__)
self.enable_visualdl = cfg.get('enable_visualdl', False)
if self.enable_visualdl:
......@@ -54,9 +107,18 @@ class Trainer:
# base config
self.output_dir = cfg.output_dir
self.epochs = cfg.epochs
self.epochs = cfg.get('epochs', None)
if self.epochs:
self.total_iters = self.epochs * self.iters_per_epoch
self.by_epoch = True
else:
self.by_epoch = False
self.total_iters = cfg.total_iters
self.start_epoch = 1
self.current_epoch = 1
self.current_iter = 1
self.inner_iter = 1
self.batch_id = 0
self.global_steps = 0
self.weight_interval = cfg.snapshot_config.interval
......@@ -69,10 +131,6 @@ class Trainer:
self.local_rank = ParallelEnv().local_rank
# time count
self.steps_per_epoch = len(self.train_dataloader)
self.total_steps = self.epochs * self.steps_per_epoch
self.time_count = {}
self.best_metric = {}
......@@ -85,117 +143,68 @@ class Trainer:
reader_cost_averager = TimeAverager()
batch_cost_averager = TimeAverager()
for epoch in range(self.start_epoch, self.epochs + 1):
self.current_epoch = epoch
start_time = step_start_time = time.time()
for i, data in enumerate(self.train_dataloader):
reader_cost_averager.record(time.time() - step_start_time)
self.batch_id = i
# unpack data from dataset and apply preprocessing
# data input should be dict
self.model.set_input(data)
self.model.optimize_parameters()
batch_cost_averager.record(time.time() - step_start_time,
num_samples=self.cfg.get(
'batch_size', 1))
if i % self.log_interval == 0:
self.data_time = reader_cost_averager.get_average()
self.step_time = batch_cost_averager.get_average()
self.ips = batch_cost_averager.get_ips_average()
self.print_log()
reader_cost_averager.reset()
batch_cost_averager.reset()
if i % self.visual_interval == 0:
self.visual('visual_train')
self.global_steps += 1
step_start_time = time.time()
self.logger.info(
'time to train one epoch: {:.3f} seconds.'.format(time.time() -
start_time))
if self.validate_interval > -1 and epoch % self.validate_interval == 0:
self.validate()
self.model.lr_scheduler.step()
if epoch % self.weight_interval == 0:
self.save(epoch, 'weight', keep=-1)
self.save(epoch)
def validate(self):
if not hasattr(self, 'val_dataloader'):
self.val_dataloader = build_dataloader(self.cfg.dataset.val,
is_train=False)
iter_loader = IterLoader(self.train_dataloader)
metric_result = {}
while self.current_iter < (self.total_iters + 1):
self.current_epoch = iter_loader.epoch
self.inner_iter = self.current_iter % self.iters_per_epoch
for i, data in enumerate(self.val_dataloader):
self.batch_id = i
self.model.set_input(data)
self.model.test()
visual_results = {}
current_paths = self.model.get_image_paths()
current_visuals = self.model.get_current_visuals()
start_time = step_start_time = time.time()
data = next(iter_loader)
reader_cost_averager.record(time.time() - step_start_time)
# unpack data from dataset and apply preprocessing
# data input should be dict
self.model.setup_input(data)
self.model.train_iter(self.optimizers)
batch_cost_averager.record(time.time() - step_start_time,
num_samples=self.cfg.get(
'batch_size', 1))
if self.current_iter % self.log_interval == 0:
self.data_time = reader_cost_averager.get_average()
self.step_time = batch_cost_averager.get_average()
self.ips = batch_cost_averager.get_ips_average()
self.print_log()
reader_cost_averager.reset()
batch_cost_averager.reset()
if self.current_iter % self.visual_interval == 0:
self.visual('visual_train')
step_start_time = time.time()
for j in range(len(current_paths)):
short_path = os.path.basename(current_paths[j])
basename = os.path.splitext(short_path)[0]
for k, img_tensor in current_visuals.items():
name = '%s_%s' % (basename, k)
visual_results.update({name: img_tensor[j]})
if 'psnr' in self.cfg.validate.metrics:
if 'psnr' not in metric_result:
metric_result['psnr'] = calculate_psnr(
tensor2img(current_visuals['output'][j], (0., 1.)),
tensor2img(current_visuals['gt'][j], (0., 1.)),
**self.cfg.validate.metrics.psnr)
else:
metric_result['psnr'] += calculate_psnr(
tensor2img(current_visuals['output'][j], (0., 1.)),
tensor2img(current_visuals['gt'][j], (0., 1.)),
**self.cfg.validate.metrics.psnr)
if 'ssim' in self.cfg.validate.metrics:
if 'ssim' not in metric_result:
metric_result['ssim'] = calculate_ssim(
tensor2img(current_visuals['output'][j], (0., 1.)),
tensor2img(current_visuals['gt'][j], (0., 1.)),
**self.cfg.validate.metrics.ssim)
else:
metric_result['ssim'] += calculate_ssim(
tensor2img(current_visuals['output'][j], (0., 1.)),
tensor2img(current_visuals['gt'][j], (0., 1.)),
**self.cfg.validate.metrics.ssim)
self.visual('visual_val',
visual_results=visual_results,
step=self.batch_id)
self.model.lr_scheduler.step()
if i % self.log_interval == 0:
self.logger.info('val iter: [%d/%d]' %
(i, len(self.val_dataloader)))
if self.by_epoch:
temp = self.current_epoch
else:
temp = self.current_iter
if self.validate_interval > -1 and temp % self.validate_interval == 0:
self.test()
for metric_name in metric_result.keys():
metric_result[metric_name] /= len(self.val_dataloader.dataset)
if temp % self.weight_interval == 0:
self.save(temp, 'weight', keep=-1)
self.save(temp)
self.logger.info('Epoch {} validate end: {}'.format(
self.current_epoch, metric_result))
self.current_iter += 1
def test(self):
if not hasattr(self, 'test_dataloader'):
self.test_dataloader = build_dataloader(self.cfg.dataset.test,
is_train=False)
is_train=False,
distributed=False)
if self.metrics:
for metric in self.metrics.values():
metric.reset()
# data[0]: img, data[1]: img path index
# test batch size must be 1
for i, data in enumerate(self.test_dataloader):
self.batch_id = i
self.model.set_input(data)
self.model.test()
self.model.setup_input(data)
self.model.test_iter(metrics=self.metrics)
visual_results = {}
current_paths = self.model.get_image_paths()
......@@ -217,11 +226,23 @@ class Trainer:
self.logger.info('Test iter: [%d/%d]' %
(i, len(self.test_dataloader)))
if self.metrics:
for metric_name, metric in self.metrics.items():
self.logger.info("Metric {}: {:.4f}".format(
metric_name, metric.accumulate()))
def print_log(self):
losses = self.model.get_current_losses()
message = 'Epoch: %d, iters: %d ' % (self.current_epoch, self.batch_id)
message += '%s: %.6f ' % ('lr', self.current_learning_rate)
message = ''
if self.by_epoch:
message += 'Epoch: %d/%d, iter: %d/%d ' % (
self.current_epoch, self.epochs, self.inner_iter,
self.iters_per_epoch)
else:
message += 'Iter: %d/%d ' % (self.current_iter, self.total_iters)
message += f'lr: {self.current_learning_rate:.3e} '
for k, v in losses.items():
message += '%s: %.3f ' % (k, v)
......@@ -238,9 +259,7 @@ class Trainer:
message += 'ips: %.5f images/s ' % self.ips
if hasattr(self, 'step_time'):
cur_step = self.steps_per_epoch * (self.current_epoch -
1) + self.batch_id
eta = self.step_time * (self.total_steps - cur_step - 1)
eta = self.step_time * (self.total_iters - self.current_iter - 1)
eta_str = str(datetime.timedelta(seconds=int(eta)))
message += f'eta: {eta_str}'
......@@ -274,6 +293,7 @@ class Trainer:
min_max = self.cfg.get('min_max', None)
if min_max is None:
min_max = (-1., 1.)
image_num = self.cfg.get('image_num', None)
if (image_num is None) or (not self.enable_visualdl):
image_num = 1
......@@ -345,6 +365,14 @@ class Trainer:
state_dicts = load(weight_path)
for net_name, net in self.model.nets.items():
if net_name in state_dicts:
net.set_state_dict(state_dicts[net_name])
self.logger.info(
'Loaded pretrained weight for net {}'.format(net_name))
else:
self.logger.warning(
'Can not find state dict of net {}. Skip load pretrained weight for net {}'
.format(net_name, net_name))
def close(self):
......
English (./README.md)
# Usage
To compute the FID score between two datasets, where images of each dataset are contained in an individual folder:
```
wget https://paddlegan.bj.bcebos.com/InceptionV3.pdparams
python test_fid_score.py --image_data_path1 /path/to/dataset1 --image_data_path2 /path/to/dataset2 --inference_model ./InceptionV3.pdparams
```
### Inception-V3 weights converted from torchvision
Download: https://aistudio.baidu.com/aistudio/datasetdetail/51890
These model weights are converted from the official torchvision Inception-V3 model. Both BigGAN and StarGAN-v2 use them to calculate the FID score.
Note that these weights differ from the ones above (which are converted from an unofficial TensorFlow version).
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
def reorder_image(img, input_order='HWC'):
"""Reorder images to 'HWC' order.
If the input_order is (h, w), return (h, w, 1);
If the input_order is (c, h, w), return (h, w, c);
If the input_order is (h, w, c), return as it is.
Args:
img (ndarray): Input image.
input_order (str): Whether the input order is 'HWC' or 'CHW'.
If the input image shape is (h, w), input_order will not have
effects. Default: 'HWC'.
Returns:
ndarray: reordered image.
"""
if input_order not in ['HWC', 'CHW']:
raise ValueError(
f'Wrong input_order {input_order}. Supported input_orders are '
"'HWC' and 'CHW'")
if len(img.shape) == 2:
img = img[..., None]
return img
if input_order == 'CHW':
img = img.transpose(1, 2, 0)
return img
def bgr2ycbcr(img, y_only=False):
"""Convert a BGR image to YCbCr image.
The bgr version of rgb2ycbcr.
It implements the ITU-R BT.601 conversion for standard-definition
television. See more details in
https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`.
In OpenCV, it implements a JPEG conversion. See more details in
https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
Args:
img (ndarray): The input image. It accepts:
1. np.uint8 type with range [0, 255];
2. np.float32 type with range [0, 1].
y_only (bool): Whether to only return Y channel. Default: False.
Returns:
ndarray: The converted YCbCr image. The output image has the same type
and range as input image.
"""
img_type = img.dtype
if y_only:
out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0
else:
out_img = np.matmul(
img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
[65.481, -37.797, 112.0]]) + [16, 128, 128]
return out_img
def to_y_channel(img):
"""Change to Y channel of YCbCr.
Args:
img (ndarray): Images with range [0, 255].
Returns:
(ndarray): Images with range [0, 255] (float type) without round.
"""
img = img.astype(np.float32) / 255.
if img.ndim == 3 and img.shape[2] == 3:
img = bgr2ycbcr(img, y_only=True)
img = img[..., None]
return img * 255.
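A quick shape-level sketch of how these helpers compose (random values; only shapes and ranges matter):
```
import numpy as np

chw = np.random.rand(3, 64, 64).astype(np.float32)
hwc = reorder_image(chw, input_order='CHW')               # -> (64, 64, 3)

bgr = (np.random.rand(64, 64, 3) * 255).astype(np.uint8)  # BGR, [0, 255]
y = to_y_channel(bgr)                                     # (64, 64, 1), float32, [0, 255]
```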
#Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import argparse
from compute_fid import *
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--image_data_path1',
type=str,
default='./real',
help='path of image data')
parser.add_argument('--image_data_path2',
type=str,
default='./fake',
help='path of image data')
parser.add_argument('--inference_model',
type=str,
default='./pretrained/params_inceptionV3',
help='path of inference_model.')
parser.add_argument('--use_gpu',
type=lambda v: str(v).lower() in ('1', 'true', 'yes'),
default=True,
help='whether to use gpu (note: argparse type=bool treats any non-empty string as True).')
parser.add_argument('--batch_size',
type=int,
default=1,
help='sample number in a batch for inference.')
parser.add_argument(
'--style',
type=str,
help='calculation style: stargan or default (gan-compression style)')
args = parser.parse_args()
return args
def main():
args = parse_args()
path1 = args.image_data_path1
path2 = args.image_data_path2
paths = (path1, path2)
inference_model_path = args.inference_model
batch_size = args.batch_size
fid_value = calculate_fid_given_paths(paths,
inference_model_path,
batch_size,
args.use_gpu,
2048,
style=args.style)
print('FID: ', fid_value)
if __name__ == "__main__":
main()
......@@ -11,3 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .psnr_ssim import PSNR, SSIM
from .builder import build_metric
......@@ -12,4 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .transforms import PairedRandomCrop, PairedRandomHorizontalFlip, Add, ResizeToScale
import copy
import paddle
from ..utils.registry import Registry
METRICS = Registry("METRIC")
def build_metric(cfg):
cfg_ = cfg.copy()
name = cfg_.pop('name', None)
metric = METRICS.get(name)(**cfg_)
return metric
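A sketch of building a metric from a config dict, matching the `validate.metrics` style consumed by the trainer (the `crop_border` value is a placeholder):
```
psnr_cfg = {'name': 'PSNR', 'crop_border': 4, 'test_y_channel': True}
psnr = build_metric(psnr_cfg)  # == PSNR(crop_border=4, test_y_channel=True)
```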
......@@ -14,8 +14,59 @@
import cv2
import numpy as np
import paddle
from .metric_util import reorder_image, to_y_channel
from .builder import METRICS
@METRICS.register()
class PSNR(paddle.metric.Metric):
def __init__(self, crop_border, input_order='HWC', test_y_channel=False):
super(PSNR, self).__init__()
self.crop_border = crop_border
self.input_order = input_order
self.test_y_channel = test_y_channel
self.reset()
def reset(self):
self.results = []
def update(self, preds, gts):
if not isinstance(preds, (list, tuple)):
preds = [preds]
if not isinstance(gts, (list, tuple)):
gts = [gts]
for pred, gt in zip(preds, gts):
value = calculate_psnr(pred, gt, self.crop_border, self.input_order,
self.test_y_channel)
self.results.append(value)
def accumulate(self):
if len(self.results) <= 0:
return 0.
return np.mean(self.results)
def name(self):
return 'PSNR'
@METRICS.register()
class SSIM(PSNR):
def update(self, preds, gts):
if not isinstance(preds, (list, tuple)):
preds = [preds]
if not isinstance(gts, (list, tuple)):
gts = [gts]
for pred, gt in zip(preds, gts):
value = calculate_ssim(pred, gt, self.crop_border, self.input_order,
self.test_y_channel)
self.results.append(value)
def name(self):
return 'SSIM'
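Both metrics follow the same reset/update/accumulate protocol the trainer uses in `test()`; a minimal sketch (comparing an image against itself, so SSIM accumulates to 1.0):
```
import numpy as np

ssim = SSIM(crop_border=0)
pred = (np.random.rand(32, 32, 3) * 255).astype(np.float32)
ssim.update(pred, pred.copy())   # identical images
print(ssim.accumulate())         # 1.0
```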
def calculate_psnr(img1,
......@@ -46,6 +97,8 @@ def calculate_psnr(img1,
raise ValueError(
f'Wrong input_order {input_order}. Supported input_orders are '
'"HWC" and "CHW"')
img1 = img1.copy().astype('float32')
img2 = img2.copy().astype('float32')
img1 = reorder_image(img1, input_order=input_order)
img2 = reorder_image(img2, input_order=input_order)
......@@ -134,6 +187,10 @@ def calculate_ssim(img1,
raise ValueError(
f'Wrong input_order {input_order}. Supported input_orders are '
'"HWC" and "CHW"')
img1 = img1.copy().astype('float32')[..., ::-1]
img2 = img2.copy().astype('float32')[..., ::-1]
img1 = reorder_image(img1, input_order=input_order)
img2 = reorder_image(img2, input_order=input_order)
......@@ -149,3 +206,81 @@ def calculate_ssim(img1,
for i in range(img1.shape[2]):
ssims.append(_ssim(img1[..., i], img2[..., i]))
return np.array(ssims).mean()
def reorder_image(img, input_order='HWC'):
"""Reorder images to 'HWC' order.
If the input_order is (h, w), return (h, w, 1);
If the input_order is (c, h, w), return (h, w, c);
If the input_order is (h, w, c), return as it is.
Args:
img (ndarray): Input image.
input_order (str): Whether the input order is 'HWC' or 'CHW'.
If the input image shape is (h, w), input_order will not have
effects. Default: 'HWC'.
Returns:
ndarray: reordered image.
"""
if input_order not in ['HWC', 'CHW']:
raise ValueError(
f'Wrong input_order {input_order}. Supported input_orders are '
"'HWC' and 'CHW'")
if len(img.shape) == 2:
img = img[..., None]
return img
if input_order == 'CHW':
img = img.transpose(1, 2, 0)
return img
def bgr2ycbcr(img, y_only=False):
"""Convert a BGR image to YCbCr image.
The bgr version of rgb2ycbcr.
It implements the ITU-R BT.601 conversion for standard-definition
television. See more details in
https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`.
In OpenCV, it implements a JPEG conversion. See more details in
https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
Args:
img (ndarray): The input image. It accepts:
1. np.uint8 type with range [0, 255];
2. np.float32 type with range [0, 1].
y_only (bool): Whether to only return Y channel. Default: False.
Returns:
ndarray: The converted YCbCr image. The output image has the same type
and range as input image.
"""
img_type = img.dtype
if y_only:
out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0
else:
out_img = np.matmul(
img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
[65.481, -37.797, 112.0]]) + [16, 128, 128]
return out_img
def to_y_channel(img):
"""Change to Y channel of YCbCr.
Args:
img (ndarray): Images with range [0, 255].
Returns:
(ndarray): Images with range [0, 255] (float type) without round.
"""
img = img.astype(np.float32) / 255.
if img.ndim == 3 and img.shape[2] == 3:
img = bgr2ycbcr(img, y_only=True)
img = img[..., None]
return img * 255.
......@@ -16,9 +16,9 @@ from .base_model import BaseModel
from .gan_model import GANModel
from .cycle_gan_model import CycleGANModel
from .pix2pix_model import Pix2PixModel
from .srgan_model import SRGANModel
from .sr_model import SRModel
from .sr_model import BaseSRModel
from .makeup_model import MakeupModel
from .esrgan_model import ESRGAN
from .ugatit_model import UGATITModel
from .dc_gan_model import DCGANModel
from .animeganv2_model import AnimeGANV2Model, AnimeGANV2PreTrainModel
......@@ -19,7 +19,7 @@ from .base_model import BaseModel
from .builder import MODELS
from .generators.builder import build_generator
from .discriminators.builder import build_discriminator
from .losses import GANLoss
from .criterions.gan_loss import GANLoss
from ..modules.caffevgg import CaffeVGG19
from ..solver import build_optimizer
from ..modules.init import init_weights
......
......@@ -19,24 +19,40 @@ import numpy as np
from collections import OrderedDict
from abc import ABC, abstractmethod
from ..solver.lr_scheduler import build_lr_scheduler
from .criterions.builder import build_criterion
from ..solver import build_lr_scheduler, build_optimizer
from ..metrics import build_metric
from ..utils.visual import tensor2img
class BaseModel(ABC):
"""This class is an abstract base class (ABC) for models.
To create a subclass, you need to implement the following four functions:
-- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
-- <set_input>: unpack data from dataset and apply preprocessing.
-- <forward>: produce intermediate results.
-- <optimize_parameters>: calculate losses, gradients, and update network weights.
-- <modify_commandline_options>: (optionally) add model-specific options and set default options.
-- <__init__>: initialize the class.
-- <setup_input>: unpack data from dataset and apply preprocessing.
-- <forward>: produce intermediate results.
-- <train_iter>: calculate losses, gradients, and update network weights.
# trainer training logic:
#
# build_model || model(BaseModel)
# | ||
# build_dataloader || dataloader
# | ||
# model.setup_lr_schedulers || lr_scheduler
# | ||
# model.setup_optimizers || optimizers
# | ||
# train loop (model.setup_input + model.train_iter) || train loop
# | ||
# print log (model.get_current_losses) ||
# | ||
# save checkpoint (model.nets) \/
"""
def __init__(self, cfg):
def __init__(self):
"""Initialize the BaseModel class.
Args:
cfg (Dict)-- configs of Model.
When creating your custom class, you need to implement your own initialization.
In this function, you should first call <super(YourClass, self).__init__()>
Then, you need to define four lists:
......@@ -47,57 +63,85 @@ class BaseModel(ABC):
If two networks are updated at the same time, you can use itertools.chain to group them.
See cycle_gan_model.py for an example.
"""
self.cfg = cfg
self.is_train = cfg.is_train
self.save_dir = os.path.join(
cfg.output_dir,
cfg.model.name) # save all the checkpoints to save_dir
self.losses = OrderedDict()
self.nets = OrderedDict()
self.visual_items = OrderedDict()
self.optimizers = OrderedDict()
self.image_paths = []
self.metric = 0 # used for learning rate policy 'plateau'
self.metrics = OrderedDict()
self.losses = OrderedDict()
self.visual_items = OrderedDict()
@abstractmethod
def set_input(self, input):
def setup_input(self, input):
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
Parameters:
Args:
input (dict): includes the data itself and its metadata information.
"""
pass
@abstractmethod
def forward(self):
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
"""Run forward pass; called by both functions <train_iter> and <test_iter>."""
pass
@abstractmethod
def optimize_parameters(self):
def train_iter(self, optims=None):
"""Calculate losses, gradients, and update network weights; called in every training iteration"""
pass
def build_lr_scheduler(self):
self.lr_scheduler = build_lr_scheduler(self.cfg.lr_scheduler)
def test_iter(self, metrics=None):
"""Calculate metrics; called in every test iteration"""
self.eval()
with paddle.no_grad():
self.forward()
self.train()
def setup_train_mode(self, is_train):
self.is_train = is_train
def setup_lr_schedulers(self, cfg):
self.lr_scheduler = build_lr_scheduler(cfg)
return self.lr_scheduler
def setup_optimizers(self, lr, cfg):
if cfg.get('name', None):
cfg_ = cfg.copy()
net_names = cfg_.pop('net_names')
parameters = []
for net_name in net_names:
parameters += self.nets[net_name].parameters()
self.optimizers['optim'] = build_optimizer(cfg_, lr, parameters)
else:
for opt_name, opt_cfg in cfg.items():
cfg_ = opt_cfg.copy()
net_names = cfg_.pop('net_names')
parameters = []
for net_name in net_names:
parameters += self.nets[net_name].parameters()
self.optimizers[opt_name] = build_optimizer(
cfg_, lr, parameters)
return self.optimizers
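`setup_optimizers` accepts either a single optimizer config (with a top-level `name`) or a dict of per-optimizer configs; a sketch of the two-optimizer form as a Python dict (the net names must match keys in `self.nets`, as in the CycleGAN model below):
```
# Hypothetical config, equivalent to a two-optimizer YAML section.
optim_cfg = {
    'optimG': {'name': 'Adam', 'net_names': ['netG_A', 'netG_B'], 'beta1': 0.5},
    'optimD': {'name': 'Adam', 'net_names': ['netD_A', 'netD_B'], 'beta1': 0.5},
}
# optimizers = model.setup_optimizers(lr_scheduler, optim_cfg)
# -> {'optimG': <Adam>, 'optimD': <Adam>}
```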
def setup_metrics(self, cfg):
if isinstance(list(cfg.values())[0], dict):
for metric_name, cfg_ in cfg.items():
self.metrics[metric_name] = build_metric(cfg_)
else:
metric = build_metric(cfg)
self.metrics[metric.__class__.__name__] = metric
return self.metrics
def eval(self):
"""Make models eval mode during test time"""
for name in self.model_names:
if isinstance(name, str):
net = getattr(self, 'net' + name)
net.eval()
def test(self):
"""Forward function used in test time.
"""Make nets eval mode during test time"""
for net in self.nets.values():
net.eval()
This function wraps <forward> function in no_grad() so we don't save intermediate steps for backprop
It also calls <compute_visuals> to produce additional visualization results
"""
with paddle.no_grad():
self.forward()
self.compute_visuals()
def train(self):
"""Make nets train mode during train time"""
for net in self.nets.values():
net.train()
def compute_visuals(self):
"""Calculate additional output images for visdom and HTML visualization"""
......@@ -118,8 +162,8 @@ class BaseModel(ABC):
def set_requires_grad(self, nets, requires_grad=False):
"""Set requies_grad=Fasle for all the networks to avoid unnecessary computations
Args:
nets (network list) -- a list of networks
requires_grad (bool) -- whether the networks require gradients or not
nets (network list): a list of networks
requires_grad (bool): whether the networks require gradients or not
"""
if not isinstance(nets, list):
nets = [nets]
......
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import paddle
from ..utils.registry import Registry
......@@ -20,5 +21,7 @@ MODELS = Registry("MODEL")
def build_model(cfg):
model = MODELS.get(cfg.model.name)(cfg)
cfg_ = cfg.copy()
name = cfg_.pop('name', None)
model = MODELS.get(name)(**cfg_)
return model
from .gan_loss import GANLoss
from .perceptual_loss import PerceptualLoss
from .pixel_loss import L1Loss, MSELoss
from .builder import build_criterion
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...utils.registry import Registry
CRITERIONS = Registry('CRITERION')
def build_criterion(cfg):
cfg_ = cfg.copy()
name = cfg_.pop('name')
try:
criterion = CRITERIONS.get(name)(**cfg_)
except Exception as e:
cls_ = CRITERIONS.get(name)
raise RuntimeError('class {} {}'.format(cls_.__name__, e))
return criterion
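Usage sketch, as the models below do it (the weights are illustrative):
```
l1 = build_criterion({'name': 'L1Loss', 'loss_weight': 10.0})
gan = build_criterion({'name': 'GANLoss', 'gan_mode': 'lsgan'})
```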
......@@ -16,29 +16,40 @@ import numpy as np
import paddle
import paddle.nn as nn
from .builder import CRITERIONS
import paddle.nn.functional as F
@CRITERIONS.register()
class GANLoss(nn.Layer):
"""Define different GAN objectives.
The GANLoss class abstracts away the need to create the target label tensor
that has the same size as the input.
"""
def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
def __init__(self,
gan_mode,
target_real_label=1.0,
target_fake_label=0.0,
loss_weight=1.0):
""" Initialize the GANLoss class.
Parameters:
gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp.
target_real_label (bool) - - label for a real image
target_fake_label (bool) - - label of a fake image
Args:
gan_mode (str): the type of GAN objective. It currently supports vanilla, lsgan, and wgangp.
target_real_label (bool): label for a real image
target_fake_label (bool): label of a fake image
Note: Do not use sigmoid as the last layer of Discriminator.
LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss.
"""
super(GANLoss, self).__init__()
# when loss weight is less than or equal to zero, return None
if loss_weight <= 0:
return None
self.target_real_label = target_real_label
self.target_fake_label = target_fake_label
self.loss_weight = loss_weight
self.gan_mode = gan_mode
if gan_mode == 'lsgan':
......@@ -53,7 +64,7 @@ class GANLoss(nn.Layer):
def get_target_tensor(self, prediction, target_is_real):
"""Create label tensors with the same size as the input.
Parameters:
Args:
prediction (tensor) - - typically the prediction from a discriminator
target_is_real (bool) - - if the ground truth label is for real images or fake images
......@@ -75,16 +86,19 @@ class GANLoss(nn.Layer):
dtype='float32')
target_tensor = self.target_fake_tensor
# target_tensor.stop_gradient = True
return target_tensor
def __call__(self, prediction, target_is_real, is_updating_D=None):
def __call__(self,
prediction,
target_is_real,
is_disc=False,
is_updating_D=None):
"""Calculate loss given Discriminator's output and grount truth labels.
Parameters:
Args:
prediction (tensor) - - typically the prediction output from a discriminator
target_is_real (bool) - - if the ground truth label is for real images or fake images
is_disc (bool) - - whether the loss is being computed for the discriminator (loss_weight is skipped if so)
is_updating_D (bool) - - if we are in updating D step or not
Returns:
the calculated loss.
......@@ -108,4 +122,5 @@ class GANLoss(nn.Layer):
loss = F.softplus(-prediction).mean()
else:
loss = F.softplus(prediction).mean()
return loss
return loss if is_disc else loss * self.loss_weight
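The `is_disc` flag keeps the discriminator update unweighted and applies `loss_weight` only on the generator side; a sketch with a made-up patch-discriminator output:
```
import paddle

gan_loss = GANLoss(gan_mode='lsgan', loss_weight=0.5)
pred_fake = paddle.randn([4, 1, 30, 30])           # hypothetical D output

loss_g = gan_loss(pred_fake, True)                 # generator loss, weighted
loss_d = gan_loss(pred_fake, False, is_disc=True)  # discriminator loss, unweighted
```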
import paddle
import paddle.nn as nn
import paddle.vision.models.vgg as vgg
from ppgan.utils.download import get_path_from_url
from .builder import CRITERIONS
class PerceptualVGG(nn.Layer):
"""VGG network used in calculating perceptual loss.
In this implementation, we allow users to choose whether use normalization
in the input feature and the type of vgg network. Note that the pretrained
path must fit the vgg type.
Args:
layer_name_list (list[str]): According to the name in this list,
forward function will return the corresponding features. This
list contains the name each layer in `vgg.feature`. An example
of this list is ['4', '10'].
vgg_type (str): Set the type of vgg network. Default: 'vgg19'.
use_input_norm (bool): If True, normalize the input image.
Importantly, the input feature must be in the range [0, 1].
Default: True.
pretrained_url (str): Path for pretrained weights. Default:
"""
def __init__(
self,
layer_name_list,
vgg_type='vgg19',
use_input_norm=True,
pretrained_url='https://paddlegan.bj.bcebos.com/model/vgg19.pdparams'
):
super(PerceptualVGG, self).__init__()
self.layer_name_list = layer_name_list
self.use_input_norm = use_input_norm
# get vgg model and load pretrained vgg weight
_vgg = getattr(vgg, vgg_type)()
if pretrained_url:
weight_path = get_path_from_url(pretrained_url)
state_dict = paddle.load(weight_path)
_vgg.load_dict(state_dict)
print('PerceptualVGG loaded pretrained weight.')
num_layers = max(map(int, layer_name_list)) + 1
assert len(_vgg.features) >= num_layers
# only borrow layers that will be used from _vgg to avoid unused params
self.vgg_layers = nn.Sequential(
*list(_vgg.features.children())[:num_layers])
if self.use_input_norm:
# the mean is for image with range [0, 1]
self.register_buffer(
'mean',
paddle.to_tensor([0.485, 0.456, 0.406]).reshape([1, 3, 1, 1]))
# the std is for image with range [0, 1]
self.register_buffer(
'std',
paddle.to_tensor([0.229, 0.224, 0.225]).reshape([1, 3, 1, 1]))
for v in self.vgg_layers.parameters():
v.trainable = False
def forward(self, x):
"""Forward function.
Args:
x (Tensor): Input tensor with shape (n, c, h, w).
Returns:
Tensor: Forward results.
"""
if self.use_input_norm:
x = (x - self.mean) / self.std
output = {}
for name, module in self.vgg_layers.named_children():
x = module(x)
if name in self.layer_name_list:
output[name] = x.clone()
return output
@CRITERIONS.register()
class PerceptualLoss(nn.Layer):
"""Perceptual loss with commonly used style loss.
Args:
layers_weights (dict): The weight for each layer of vgg feature.
Here is an example: {'4': 1., '9': 1., '18': 1.}, which means the
5th, 10th and 18th feature layer will be extracted with weight 1.0
in calculating losses.
vgg_type (str): The type of vgg network used as feature extractor.
Default: 'vgg19'.
use_input_norm (bool): If True, normalize the input image in vgg.
Default: True.
perceptual_weight (float): If `perceptual_weight > 0`, the perceptual
loss will be calculated and multiplied by the weight.
Default: 1.0.
style_weight (float): If `style_weight > 0`, the style loss will be
calculated and multiplied by the weight.
Default: 1.0.
norm_img (bool): If True, the image will be normed to [0, 1]. Note that
this is different from `use_input_norm`, which normalizes the input
in the forward function of vgg according to the statistics of the dataset.
Importantly, the input image must be in the range [-1, 1].
pretrained (str): Path for pretrained weights. Default:
"""
def __init__(
self,
layer_weights,
vgg_type='vgg19',
use_input_norm=True,
perceptual_weight=1.0,
style_weight=1.0,
norm_img=True,
pretrained='https://paddlegan.bj.bcebos.com/model/vgg19.pdparams',
criterion='l1'):
super(PerceptualLoss, self).__init__()
# when loss weight is less than or equal to zero, return None
if perceptual_weight <= 0 and style_weight <= 0:
return None
self.norm_img = norm_img
self.perceptual_weight = perceptual_weight
self.style_weight = style_weight
self.layer_weights = layer_weights
self.vgg = PerceptualVGG(layer_name_list=list(layer_weights.keys()),
vgg_type=vgg_type,
use_input_norm=use_input_norm,
pretrained_url=pretrained)
if criterion == 'l1':
self.criterion = nn.L1Loss()
else:
raise NotImplementedError(
f'{criterion} criterion has not been supported in'
' this version.')
def forward(self, x, gt):
"""Forward function.
Args:
x (Tensor): Input tensor with shape (n, c, h, w).
gt (Tensor): Ground-truth tensor with shape (n, c, h, w).
Returns:
Tensor: Forward results.
"""
if self.norm_img:
x = (x + 1.) * 0.5
gt = (gt + 1.) * 0.5
# extract vgg features
x_features = self.vgg(x)
gt_features = self.vgg(gt.detach())
# calculate perceptual loss
if self.perceptual_weight > 0:
percep_loss = 0
for k in x_features.keys():
percep_loss += self.criterion(
x_features[k], gt_features[k]) * self.layer_weights[k]
percep_loss *= self.perceptual_weight
else:
percep_loss = None
# calculate style loss
if self.style_weight > 0:
style_loss = 0
for k in x_features.keys():
style_loss += self.criterion(self._gram_mat(
x_features[k]), self._gram_mat(
gt_features[k])) * self.layer_weights[k]
style_loss *= self.style_weight
else:
style_loss = None
return percep_loss, style_loss
def _gram_mat(self, x):
"""Calculate Gram matrix.
Args:
x (paddle.Tensor): Tensor with shape of (n, c, h, w).
Returns:
paddle.Tensor: Gram matrix.
"""
(n, c, h, w) = x.shape
features = x.reshape([n, c, w * h])
features_t = features.transpose([0, 2, 1])
gram = features.bmm(features_t) / (c * h * w)
return gram
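A sketch of the two-term return value (this downloads the VGG19 weights on first use; the layer indices are illustrative):
```
import paddle

ploss = PerceptualLoss(layer_weights={'4': 1.0, '9': 1.0, '18': 1.0},
                       style_weight=0.)            # disable the style term
x = paddle.rand([1, 3, 128, 128]) * 2 - 1          # inputs expected in [-1, 1]
gt = paddle.rand([1, 3, 128, 128]) * 2 - 1
percep, style = ploss(x, gt)                       # style is None here
```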
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
import paddle.nn as nn
from .builder import CRITERIONS
@CRITERIONS.register()
class L1Loss():
"""L1 (mean absolute error, MAE) loss.
Args:
reduction (str): Specifies the reduction to apply to the output.
Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
loss_weight (float): Loss weight for L1 loss. Default: 1.0.
"""
def __init__(self, reduction='mean', loss_weight=1.0):
# when loss weight is less than or equal to zero, return None
if loss_weight <= 0:
return None
self._l1_loss = nn.L1Loss(reduction)
self.loss_weight = loss_weight
self.reduction = reduction
def __call__(self, pred, target, **kwargs):
"""Forward Function.
Args:
pred (Tensor): of shape (N, C, H, W). Predicted tensor.
target (Tensor): of shape (N, C, H, W). Ground truth tensor.
weight (Tensor, optional): of shape (N, C, H, W). Element-wise
weights. Default: None.
"""
return self.loss_weight * self._l1_loss(pred, target)
@CRITERIONS.register()
class MSELoss():
"""MSE (L2) loss.
Args:
reduction (str): Specifies the reduction to apply to the output.
Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
loss_weight (float): Loss weight for MSE loss. Default: 1.0.
"""
def __init__(self, reduction='mean', loss_weight=1.0):
# when loss weight is less than or equal to zero, return None
if loss_weight <= 0:
return None
self._l2_loss = nn.MSELoss(reduction)
self.loss_weight = loss_weight
self.reduction = reduction
def __call__(self, pred, target, **kwargs):
"""Forward Function.
Args:
pred (Tensor): of shape (N, C, H, W). Predicted tensor.
target (Tensor): of shape (N, C, H, W). Ground truth tensor.
weight (Tensor, optional): of shape (N, C, H, W). Element-wise
weights. Default: None.
"""
return self.loss_weight * self._l2_loss(pred, target)
@CRITERIONS.register()
class BCEWithLogitsLoss():
"""BCE loss.
Args:
reduction (str): Specifies the reduction to apply to the output.
Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
loss_weight (float): Loss weight for MSE loss. Default: 1.0.
"""
def __init__(self, reduction='mean', loss_weight=1.0):
# when loss weight is less than or equal to zero, return None
if loss_weight <= 0:
return None
self._bce_loss = nn.BCEWithLogitsLoss(reduction=reduction)
self.loss_weight = loss_weight
self.reduction = reduction
def __call__(self, pred, target, **kwargs):
"""Forward Function.
Args:
pred (Tensor): of shape (N, C, H, W). Predicted tensor.
target (Tensor): of shape (N, C, H, W). Ground truth tensor.
weight (Tensor, optional): of shape (N, C, H, W). Element-wise
weights. Default: None.
"""
return self.loss_weight * self._bce_loss(pred, target)
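All three pixel criterions share the same call signature; a sketch (the weight is illustrative):
```
import paddle

pix_loss = L1Loss(loss_weight=100.)
pred = paddle.rand([2, 3, 64, 64])
gt = paddle.rand([2, 3, 64, 64])
loss = pix_loss(pred, gt)    # mean absolute error scaled by loss_weight
```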
......@@ -18,9 +18,8 @@ from .base_model import BaseModel
from .builder import MODELS
from .generators.builder import build_generator
from .discriminators.builder import build_discriminator
from .losses import GANLoss
from .criterions import build_criterion
from ..solver import build_optimizer
from ..modules.init import init_weights
from ..utils.image_pool import ImagePool
......@@ -30,61 +29,62 @@ class CycleGANModel(BaseModel):
"""
This class implements the CycleGAN model, for learning image-to-image translation without paired data.
The model training requires '--dataset_mode unaligned' dataset.
By default, it uses a '--netG resnet_9blocks' ResNet generator,
a '--netD basic' discriminator (PatchGAN introduced by pix2pix),
and a least-square GANs objective ('--gan_mode lsgan').
CycleGAN paper: https://arxiv.org/pdf/1703.10593.pdf
"""
def __init__(self, cfg):
def __init__(self,
generator,
discriminator=None,
cycle_criterion=None,
idt_criterion=None,
gan_criterion=None,
pool_size=50,
direction='a2b',
lambda_a=10.,
lambda_b=10.):
"""Initialize the CycleGAN class.
Parameters:
opt (config)-- stores all the experiment flags; needs to be a subclass of Dict
Args:
generator (dict): config of generator.
discriminator (dict): config of discriminator.
cycle_criterion (dict): config of cycle criterion.
"""
super(CycleGANModel, self).__init__(cfg)
super(CycleGANModel, self).__init__()
self.direction = direction
# define networks (both Generators and discriminators)
self.lambda_a = lambda_a
self.lambda_b = lambda_b
# define generators
# The naming is different from those used in the paper.
# Code (vs. paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
self.nets['netG_A'] = build_generator(cfg.model.generator)
self.nets['netG_B'] = build_generator(cfg.model.generator)
self.nets['netG_A'] = build_generator(generator)
self.nets['netG_B'] = build_generator(generator)
init_weights(self.nets['netG_A'])
init_weights(self.nets['netG_B'])
if self.is_train: # define discriminators
self.nets['netD_A'] = build_discriminator(cfg.model.discriminator)
self.nets['netD_B'] = build_discriminator(cfg.model.discriminator)
# define discriminators
if discriminator:
self.nets['netD_A'] = build_discriminator(discriminator)
self.nets['netD_B'] = build_discriminator(discriminator)
init_weights(self.nets['netD_A'])
init_weights(self.nets['netD_B'])
if self.is_train:
if cfg.lambda_identity > 0.0: # only works when input and output images have the same number of channels
assert (
cfg.dataset.train.input_nc == cfg.dataset.train.output_nc)
# create image buffer to store previously generated images
self.fake_A_pool = ImagePool(cfg.dataset.train.pool_size)
# create image buffer to store previously generated images
self.fake_B_pool = ImagePool(cfg.dataset.train.pool_size)
# define loss functions
self.criterionGAN = GANLoss(cfg.model.gan_mode)
self.criterionCycle = paddle.nn.L1Loss()
self.criterionIdt = paddle.nn.L1Loss()
self.build_lr_scheduler()
self.optimizers['optimizer_G'] = build_optimizer(
cfg.optimizer,
self.lr_scheduler,
parameter_list=self.nets['netG_A'].parameters() +
self.nets['netG_B'].parameters())
self.optimizers['optimizer_D'] = build_optimizer(
cfg.optimizer,
self.lr_scheduler,
parameter_list=self.nets['netD_A'].parameters() +
self.nets['netD_B'].parameters())
def set_input(self, input):
# create image buffer to store previously generated images
self.fake_A_pool = ImagePool(pool_size)
# create image buffer to store previously generated images
self.fake_B_pool = ImagePool(pool_size)
# define loss functions
if gan_criterion:
self.gan_criterion = build_criterion(gan_criterion)
if cycle_criterion:
self.cycle_criterion = build_criterion(cycle_criterion)
if idt_criterion:
self.idt_criterion = build_criterion(idt_criterion)
def setup_input(self, input):
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
Args:
......@@ -92,8 +92,8 @@ class CycleGANModel(BaseModel):
The option 'direction' can be used to swap domain A and domain B.
"""
mode = 'train' if self.is_train else 'test'
AtoB = self.cfg.dataset[mode].direction == 'AtoB'
AtoB = self.direction == 'a2b'
if AtoB:
if 'A' in input:
......@@ -134,20 +134,22 @@ class CycleGANModel(BaseModel):
def backward_D_basic(self, netD, real, fake):
"""Calculate GAN loss for the discriminator
Parameters:
netD (network) -- the discriminator D
real (tensor array) -- real images
fake (tensor array) -- images generated by a generator
Args:
netD (Layer): the discriminator D
real (paddle.Tensor): real images
fake (paddle.Tensor): images generated by a generator
Return:
the discriminator loss.
Return the discriminator loss.
We also call loss_D.backward() to calculate the gradients.
"""
# Real
pred_real = netD(real)
loss_D_real = self.criterionGAN(pred_real, True)
loss_D_real = self.gan_criterion(pred_real, True)
# Fake
pred_fake = netD(fake.detach())
loss_D_fake = self.criterionGAN(pred_fake, False)
loss_D_fake = self.gan_criterion(pred_fake, False)
# Combined loss and calculate gradients
loss_D = (loss_D_real + loss_D_fake) * 0.5
......@@ -170,16 +172,13 @@ class CycleGANModel(BaseModel):
def backward_G(self):
"""Calculate the loss for generators G_A and G_B"""
lambda_idt = self.cfg.lambda_identity
lambda_A = self.cfg.lambda_A
lambda_B = self.cfg.lambda_B
# Identity loss
if lambda_idt > 0:
if self.idt_criterion:
# G_A should be identity if real_B is fed: ||G_A(B) - B||
self.idt_A = self.nets['netG_A'](self.real_B)
self.loss_idt_A = self.criterionIdt(
self.idt_A, self.real_B) * lambda_B * lambda_idt
self.loss_idt_A = self.idt_criterion(self.idt_A,
self.real_B) * self.lambda_b
# G_B should be identity if real_A is fed: ||G_B(A) - A||
self.idt_B = self.nets['netG_B'](self.real_A)
......@@ -187,24 +186,24 @@ class CycleGANModel(BaseModel):
self.visual_items['idt_A'] = self.idt_A
self.visual_items['idt_B'] = self.idt_B
self.loss_idt_B = self.criterionIdt(
self.idt_B, self.real_A) * lambda_A * lambda_idt
self.loss_idt_B = self.idt_criterion(self.idt_B,
self.real_A) * self.lambda_a
else:
self.loss_idt_A = 0
self.loss_idt_B = 0
# GAN loss D_A(G_A(A))
self.loss_G_A = self.criterionGAN(self.nets['netD_A'](self.fake_B),
True)
self.loss_G_A = self.gan_criterion(self.nets['netD_A'](self.fake_B),
True)
# GAN loss D_B(G_B(B))
self.loss_G_B = self.criterionGAN(self.nets['netD_B'](self.fake_A),
True)
self.loss_G_B = self.gan_criterion(self.nets['netD_B'](self.fake_A),
True)
# Forward cycle loss || G_B(G_A(A)) - A||
self.loss_cycle_A = self.criterionCycle(self.rec_A,
self.real_A) * lambda_A
self.loss_cycle_A = self.cycle_criterion(self.rec_A,
self.real_A) * self.lambda_a
# Backward cycle loss || G_A(G_B(B)) - B||
self.loss_cycle_B = self.criterionCycle(self.rec_B,
self.real_B) * lambda_B
self.loss_cycle_B = self.cycle_criterion(self.rec_B,
self.real_B) * self.lambda_b
self.losses['G_idt_A_loss'] = self.loss_idt_A
self.losses['G_idt_B_loss'] = self.loss_idt_B
......@@ -217,7 +216,7 @@ class CycleGANModel(BaseModel):
self.loss_G.backward()
def optimize_parameters(self):
def train_iter(self, optimizers=None):
"""Calculate losses, gradients, and update network weights; called in every training iteration"""
# forward
# compute fake images and reconstruction images.
......@@ -227,19 +226,19 @@ class CycleGANModel(BaseModel):
self.set_requires_grad([self.nets['netD_A'], self.nets['netD_B']],
False)
# set G_A and G_B's gradients to zero
self.optimizers['optimizer_G'].clear_grad()
optimizers['optimG'].clear_grad()
# calculate gradients for G_A and G_B
self.backward_G()
# update G_A and G_B's weights
self.optimizers['optimizer_G'].step()
self.optimizers['optimG'].step()
# D_A and D_B
self.set_requires_grad([self.nets['netD_A'], self.nets['netD_B']], True)
# set D_A and D_B's gradients to zero
self.optimizers['optimizer_D'].clear_grad()
optimizers['optimD'].clear_grad()
# calculate gradients for D_A
self.backward_D_A()
# calculate gradients for D_B
self.backward_D_B()
# update D_A and D_B's weights
self.optimizers['optimizer_D'].step()
optimizers['optimD'].step()
......@@ -18,70 +18,54 @@ from .base_model import BaseModel
from .builder import MODELS
from .generators.builder import build_generator
from .discriminators.builder import build_discriminator
from .losses import GANLoss
from ..solver import build_optimizer
from .criterions import build_criterion
from ..modules.init import init_weights
@MODELS.register()
class DCGANModel(BaseModel):
""" This class implements the DCGAN model, for learning a distribution from input images.
The model training requires dataset.
By default, it uses a '--netG DCGenerator' generator,
a '--netD DCDiscriminator' discriminator,
and a vanilla GAN loss (the cross-entropy objective used in the original GAN paper).
"""
This class implements the DCGAN model, for learning a distribution from input images.
DCGAN paper: https://arxiv.org/pdf/1511.06434
"""
def __init__(self, cfg):
def __init__(self, generator, discriminator=None, gan_criterion=None):
"""Initialize the DCGAN class.
Parameters:
opt (config dict)-- stores all the experiment flags; needs to be a subclass of Dict
Args:
generator (dict): config of generator.
discriminator (dict): config of discriminator.
gan_criterion (dict): config of gan criterion.
"""
super(DCGANModel, self).__init__(cfg)
super(DCGANModel, self).__init__()
self.gen_cfg = generator
# define networks (both generator and discriminator)
self.nets['netG'] = build_generator(cfg.model.generator)
self.nets['netG'] = build_generator(generator)
init_weights(self.nets['netG'])
self.cfg = cfg
if self.is_train:
self.nets['netD'] = build_discriminator(cfg.model.discriminator)
self.nets['netD'] = build_discriminator(discriminator)
init_weights(self.nets['netD'])
if self.is_train:
self.losses = {}
# define loss functions
self.criterionGAN = GANLoss(cfg.model.gan_mode)
# build optimizers
self.build_lr_scheduler()
self.optimizers['optimizer_G'] = build_optimizer(
cfg.optimizer,
self.lr_scheduler,
parameter_list=self.nets['netG'].parameters())
self.optimizers['optimizer_D'] = build_optimizer(
cfg.optimizer,
self.lr_scheduler,
parameter_list=self.nets['netD'].parameters())
def set_input(self, input):
if gan_criterion:
self.gan_criterion = build_criterion(gan_criterion)
def setup_input(self, input):
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
Parameters:
Args:
input (dict): include the data itself and its metadata information.
"""
# get 1-channel gray image, or 3-channel color image
self.real = paddle.to_tensor(input['A'][:,0:self.cfg.model.generator.input_nc,:,:])
self.image_paths = input['A_paths']
self.real = paddle.to_tensor(input['A'])
self.image_paths = input['A_path']
def forward(self):
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
"""Run forward pass; called by both functions <train_iter> and <test_iter>."""
# generate random noise and fake image
self.z = paddle.rand(shape=(self.real.shape[0],self.cfg.model.generator.input_nz,1,1))
self.fake = self.nets['netG'](self.z)
self.z = paddle.rand(shape=(self.real.shape[0], self.gen_cfg.input_nz,
1, 1))
self.fake = self.nets['netG'](self.z)
# put items to visual dict
self.visual_items['real'] = self.real
......@@ -91,10 +75,10 @@ class DCGANModel(BaseModel):
"""Calculate GAN loss for the discriminator"""
# Fake; stop backprop to the generator by detaching fake
pred_fake = self.nets['netD'](self.fake.detach())
self.loss_D_fake = self.criterionGAN(pred_fake, False)
self.loss_D_fake = self.gan_criterion(pred_fake, False)
pred_real = self.nets['netD'](self.real)
self.loss_D_real = self.criterionGAN(pred_real, True)
self.loss_D_real = self.gan_criterion(pred_real, True)
# combine loss and calculate gradients
self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
......@@ -108,7 +92,7 @@ class DCGANModel(BaseModel):
"""Calculate GAN loss for the generator"""
# G(A) should fake the discriminator
pred_fake = self.nets['netD'](self.fake)
self.loss_G_GAN = self.criterionGAN(pred_fake, True)
self.loss_G_GAN = self.gan_criterion(pred_fake, True)
# combine loss and calculate gradients
self.loss_G = self.loss_G_GAN
......@@ -117,7 +101,7 @@ class DCGANModel(BaseModel):
self.losses['G_adv_loss'] = self.loss_G_GAN
def optimize_parameters(self):
def train_iter(self, optimizers=None):
# compute fake images: G(A)
self.forward()
......@@ -133,4 +117,4 @@ class DCGANModel(BaseModel):
self.set_requires_grad(self.nets['netG'], True)
self.optimizers['optimizer_G'].clear_grad()
self.backward_G()
self.optimizers['optimizer_G'].step()
\ No newline at end of file
self.optimizers['optimizer_G'].step()
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .vgg_discriminator import VGGDiscriminator128
from .nlayers import NLayerDiscriminator, NLayerDiscriminatorWithClassification
from .discriminator_ugatit import UGATITDiscriminator
from .dcdiscriminator import DCDiscriminator
......
import paddle.nn as nn
from .builder import DISCRIMINATORS
@DISCRIMINATORS.register()
class VGGDiscriminator128(nn.Layer):
"""VGG style discriminator with input size 128 x 128.
It is used to train SRGAN and ESRGAN.
Args:
in_channels (int): Channel number of inputs. Default: 3.
num_feat (int): Channel number of base intermediate features.
Default: 64.
"""
def __init__(self, in_channels, num_feat, norm_layer='batch'):
super(VGGDiscriminator128, self).__init__()
self.conv0_0 = nn.Conv2D(in_channels, num_feat, 3, 1, 1, bias_attr=True)
self.conv0_1 = nn.Conv2D(num_feat, num_feat, 4, 2, 1, bias_attr=False)
self.bn0_1 = nn.BatchNorm2D(num_feat)
self.conv1_0 = nn.Conv2D(num_feat,
num_feat * 2,
3,
1,
1,
bias_attr=False)
self.bn1_0 = nn.BatchNorm2D(num_feat * 2)
self.conv1_1 = nn.Conv2D(num_feat * 2,
num_feat * 2,
4,
2,
1,
bias_attr=False)
self.bn1_1 = nn.BatchNorm2D(num_feat * 2)
self.conv2_0 = nn.Conv2D(num_feat * 2,
num_feat * 4,
3,
1,
1,
bias_attr=False)
self.bn2_0 = nn.BatchNorm2D(num_feat * 4)
self.conv2_1 = nn.Conv2D(num_feat * 4,
num_feat * 4,
4,
2,
1,
bias_attr=False)
self.bn2_1 = nn.BatchNorm2D(num_feat * 4)
self.conv3_0 = nn.Conv2D(num_feat * 4,
num_feat * 8,
3,
1,
1,
bias_attr=False)
self.bn3_0 = nn.BatchNorm2D(num_feat * 8)
self.conv3_1 = nn.Conv2D(num_feat * 8,
num_feat * 8,
4,
2,
1,
bias_attr=False)
self.bn3_1 = nn.BatchNorm2D(num_feat * 8)
self.conv4_0 = nn.Conv2D(num_feat * 8,
num_feat * 8,
3,
1,
1,
bias_attr=False)
self.bn4_0 = nn.BatchNorm2D(num_feat * 8)
self.conv4_1 = nn.Conv2D(num_feat * 8,
num_feat * 8,
4,
2,
1,
bias_attr=False)
self.bn4_1 = nn.BatchNorm2D(num_feat * 8)
self.linear1 = nn.Linear(num_feat * 8 * 4 * 4, 100)
self.linear2 = nn.Linear(100, 1)
# activation function
self.lrelu = nn.LeakyReLU(negative_slope=0.2)
def forward(self, x):
assert x.shape[2] == 128 and x.shape[3] == 128, (
f'Input spatial size must be 128x128, '
f'but received {x.shape}.')
feat = self.lrelu(self.conv0_0(x))
feat = self.lrelu(self.bn0_1(
self.conv0_1(feat))) # output spatial size: (64, 64)
feat = self.lrelu(self.bn1_0(self.conv1_0(feat)))
feat = self.lrelu(self.bn1_1(
self.conv1_1(feat))) # output spatial size: (32, 32)
feat = self.lrelu(self.bn2_0(self.conv2_0(feat)))
feat = self.lrelu(self.bn2_1(
self.conv2_1(feat))) # output spatial size: (16, 16)
feat = self.lrelu(self.bn3_0(self.conv3_0(feat)))
feat = self.lrelu(self.bn3_1(
self.conv3_1(feat))) # output spatial size: (8, 8)
feat = self.lrelu(self.bn4_0(self.conv4_0(feat)))
feat = self.lrelu(self.bn4_1(
self.conv4_1(feat))) # output spatial size: (4, 4)
feat = feat.reshape([feat.shape[0], -1])
feat = self.lrelu(self.linear1(feat))
out = self.linear2(feat)
return out
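A quick usage sketch for the class above (batch size 4 is arbitrary): the five stride-2 convolutions reduce a 128x128 input to 4x4 before the two linear layers emit one logit per image.

import paddle

disc = VGGDiscriminator128(in_channels=3, num_feat=64)
x = paddle.rand([4, 3, 128, 128])  # NCHW batch of RGB crops
logits = disc(x)
print(logits.shape)  # [4, 1]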
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from .generators.builder import build_generator
from .discriminators.builder import build_discriminator
from .sr_model import BaseSRModel
from .builder import MODELS
from .criterions import build_criterion
@MODELS.register()
class ESRGAN(BaseSRModel):
"""
This class implements the ESRGAN model.
ESRGAN paper: https://arxiv.org/pdf/1809.00219.pdf
"""
def __init__(self,
generator,
discriminator=None,
pixel_criterion=None,
perceptual_criterion=None,
gan_criterion=None):
"""Initialize the ESRGAN class.
Args:
generator (dict): config of generator.
discriminator (dict): config of discriminator.
pixel_criterion (dict): config of pixel criterion.
perceptual_criterion (dict): config of perceptual criterion.
gan_criterion (dict): config of gan criterion.
"""
super(ESRGAN, self).__init__(generator)
self.nets['generator'] = build_generator(generator)
if discriminator:
self.nets['discriminator'] = build_discriminator(discriminator)
if pixel_criterion:
self.pixel_criterion = build_criterion(pixel_criterion)
if perceptual_criterion:
self.perceptual_criterion = build_criterion(perceptual_criterion)
if gan_criterion:
self.gan_criterion = build_criterion(gan_criterion)
def train_iter(self, optimizers=None):
self.set_requires_grad(self.nets['discriminator'], False)
optimizers['optimG'].clear_grad()
l_total = 0
self.output = self.nets['generator'](self.lq)
self.visual_items['output'] = self.output
# pixel loss
if self.pixel_criterion:
l_pix = self.pixel_criterion(self.output, self.gt)
l_total += l_pix
self.losses['loss_pix'] = l_pix
if self.perceptual_criterion:
l_g_percep, l_g_style = self.perceptual_criterion(
self.output, self.gt)
# l_total += l_pix
if l_g_percep is not None:
l_total += l_g_percep
self.losses['loss_percep'] = l_g_percep
if l_g_style is not None:
l_total += l_g_style
self.losses['loss_style'] = l_g_style
# gan loss (relativistic gan)
real_d_pred = self.nets['discriminator'](self.gt).detach()
fake_g_pred = self.nets['discriminator'](self.output)
l_g_real = self.gan_criterion(real_d_pred - paddle.mean(fake_g_pred),
False,
is_disc=False)
l_g_fake = self.gan_criterion(fake_g_pred - paddle.mean(real_d_pred),
True,
is_disc=False)
l_g_gan = (l_g_real + l_g_fake) / 2
l_total += l_g_gan
self.losses['l_g_gan'] = l_g_gan
l_total.backward()
optimizers['optimG'].step()
self.set_requires_grad(self.nets['discriminator'], True)
optimizers['optimD'].clear_grad()
# real
fake_d_pred = self.nets['discriminator'](self.output).detach()
real_d_pred = self.nets['discriminator'](self.gt)
l_d_real = self.gan_criterion(
real_d_pred - paddle.mean(fake_d_pred), True, is_disc=True) * 0.5
# fake
fake_d_pred = self.nets['discriminator'](self.output.detach())
l_d_fake = self.gan_criterion(
fake_d_pred - paddle.mean(real_d_pred.detach()),
False,
is_disc=True) * 0.5
(l_d_real + l_d_fake).backward()
optimizers['optimD'].step()
self.losses['l_d_real'] = l_d_real
self.losses['l_d_fake'] = l_d_fake
self.losses['out_d_real'] = paddle.mean(real_d_pred.detach())
self.losses['out_d_fake'] = paddle.mean(fake_d_pred.detach())
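The two blocks above implement the relativistic average GAN objective from the ESRGAN paper: each logit is judged against the mean logit of the opposite class. A standalone sketch of the same terms, assuming gan_criterion follows the (pred, target_is_real, is_disc) call signature used above:

import paddle


def ragan_g_loss(gan_criterion, real_pred, fake_pred):
    # generator side: push fake logits above the mean real logit and vice versa
    l_real = gan_criterion(real_pred - paddle.mean(fake_pred), False, is_disc=False)
    l_fake = gan_criterion(fake_pred - paddle.mean(real_pred), True, is_disc=False)
    return (l_real + l_fake) / 2


def ragan_d_loss(gan_criterion, real_pred, fake_pred):
    # discriminator side: opposite targets, each half weighted by 0.5
    l_real = gan_criterion(real_pred - paddle.mean(fake_pred.detach()), True, is_disc=True) * 0.5
    l_fake = gan_criterion(fake_pred - paddle.mean(real_pred.detach()), False, is_disc=True) * 0.5
    return l_real + l_fake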
@@ -19,7 +19,7 @@ from .base_model import BaseModel
from .builder import MODELS
from .generators.builder import build_generator
from .discriminators.builder import build_discriminator
from .criterions.gan_loss import GANLoss
from ..solver import build_optimizer
from ..modules.init import init_weights
@@ -80,9 +80,11 @@ class GANModel(BaseModel):
if not isinstance(input, dict):
input = {'img': input}
self.D_real_inputs = [paddle.to_tensor(input['img'])]
if 'class_id' in input:  # n class input
self.n_class = self.nets['netG'].n_class
self.D_real_inputs += [paddle.to_tensor(input['class_id'], dtype='int64')]
else:
self.n_class = 0
@@ -97,15 +99,18 @@ class GANModel(BaseModel):
rows_num = (batch_size - 1) // self.samples_every_row + 1
class_ids = paddle.randint(0, self.n_class, [rows_num, 1])
class_ids = class_ids.tile([1, self.samples_every_row])
class_ids = class_ids.reshape([-1])[:batch_size].detach()
self.G_fixed_inputs[1] = class_ids.detach()
def forward(self):
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
self.fake_imgs = self.nets['netG'](*self.G_inputs)  # G(img, class_id)
# put items to visual dict
self.visual_items['fake_imgs'] = make_grid(self.fake_imgs, self.samples_every_row).detach()
def backward_D(self):
"""Calculate GAN loss for the discriminator"""
@@ -118,7 +123,8 @@ class GANModel(BaseModel):
pred_fake = self.nets['netD'](*self.D_fake_inputs)
# Real
real_imgs = self.D_real_inputs[0]
self.visual_items['real_imgs'] = make_grid(real_imgs, self.samples_every_row).detach()
pred_real = self.nets['netD'](*self.D_real_inputs)
self.loss_D_fake = self.criterionGAN(pred_fake, False, True)
@@ -126,7 +132,8 @@ class GANModel(BaseModel):
# combine loss and calculate gradients
if self.cfg.model.gan_mode in ['vanilla', 'lsgan']:
self.loss_D = self.loss_D + (self.loss_D_fake + self.loss_D_real) * 0.5
else:
self.loss_D = self.loss_D + self.loss_D_fake + self.loss_D_real
@@ -179,7 +186,7 @@ class GANModel(BaseModel):
if self.step % self.visual_interval == 0:
with paddle.no_grad():
self.visual_items['fixed_generated_imgs'] = make_grid(
self.nets['netG'](*self.G_fixed_inputs), self.samples_every_row)
self.step += 1
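make_grid above tiles a batch of images into a single grid for visual logging. A self-contained stand-in with the same intent (the repository's own helper may differ in padding and normalization):

import paddle


def make_grid_sketch(imgs, nrow):
    # imgs: [N, C, H, W] -> grid: [C, rows * H, nrow * W]
    n, c, h, w = imgs.shape
    rows = (n + nrow - 1) // nrow
    pad = rows * nrow - n
    if pad:
        imgs = paddle.concat([imgs, paddle.zeros([pad, c, h, w])], axis=0)
    imgs = imgs.reshape([rows, nrow, c, h, w])
    imgs = imgs.transpose([2, 0, 3, 1, 4])  # -> [C, rows, H, nrow, W]
    return imgs.reshape([c, rows * h, nrow * w])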
@@ -15,8 +15,7 @@ import os
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.vision.models import vgg16
from paddle.utils.download import get_path_from_url
from .base_model import BaseModel
@@ -24,12 +23,10 @@ from .base_model import BaseModel
from .builder import MODELS
from .generators.builder import build_generator
from .discriminators.builder import build_discriminator
from .criterions import build_criterion
from ..modules.init import init_weights
from ..solver import build_optimizer
from ..utils.image_pool import ImagePool
from ..utils.preprocess import *
from ..datasets.makeup_dataset import MakeupDataset
VGGFACE_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/vggface.pdparams'
@@ -39,18 +36,32 @@ class MakeupModel(BaseModel):
"""
PSGAN paper: https://arxiv.org/pdf/1909.06956.pdf
"""
def __init__(self,
generator,
discriminator=None,
cycle_criterion=None,
idt_criterion=None,
gan_criterion=None,
l1_criterion=None,
l2_criterion=None,
pool_size=50,
direction='a2b',
lambda_a=10.,
lambda_b=10.,
is_train=True):
"""Initialize the PSGAN class.
Args:
generator (dict): config of generator.
discriminator (dict): config of discriminator.
"""
super(MakeupModel, self).__init__()
self.lambda_a = lambda_a
self.lambda_b = lambda_b
self.is_train = is_train
# define networks (both Generators and discriminators)
# The naming is different from those used in the paper.
# Code (vs. paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
self.nets['netG'] = build_generator(generator)
init_weights(self.nets['netG'], init_type='xavier', init_gain=1.0)
if self.is_train: # define discriminators
@@ -61,46 +72,33 @@ class MakeupModel(BaseModel):
param = paddle.load(vgg_weight_path)
vgg.load_dict(param)
self.nets['netD_A'] = build_discriminator(discriminator)
self.nets['netD_B'] = build_discriminator(discriminator)
init_weights(self.nets['netD_A'], init_type='xavier', init_gain=1.0)
init_weights(self.nets['netD_B'], init_type='xavier', init_gain=1.0)
# create image buffer to store previously generated images
self.fake_A_pool = ImagePool(pool_size)
self.fake_B_pool = ImagePool(pool_size)
# define loss functions
if gan_criterion:
self.gan_criterion = build_criterion(gan_criterion)
if cycle_criterion:
self.cycle_criterion = build_criterion(cycle_criterion)
if idt_criterion:
self.idt_criterion = build_criterion(idt_criterion)
if l1_criterion:
self.l1_criterion = build_criterion(l1_criterion)
if l2_criterion:
self.l2_criterion = build_criterion(l2_criterion)
def setup_input(self, input):
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
Args:
input (dict): include the data itself and its metadata information.
The option 'direction' can be used to swap domain A and domain B.
"""
self.real_A = paddle.to_tensor(input['image_A'])
self.real_B = paddle.to_tensor(input['image_B'])
@@ -143,24 +141,6 @@ class MakeupModel(BaseModel):
self.visual_items['fake_A'] = self.fake_A
self.visual_items['rec_B'] = self.rec_B
def backward_D_basic(self, netD, real, fake):
"""Calculate GAN loss for the discriminator
@@ -174,10 +154,10 @@ class MakeupModel(BaseModel):
"""
# Real
pred_real = netD(real)
loss_D_real = self.gan_criterion(pred_real, True)
# Fake
pred_fake = netD(fake.detach())
loss_D_fake = self.gan_criterion(pred_fake, False)
# Combined loss and calculate gradients
loss_D = (loss_D_real + loss_D_fake) * 0.5
loss_D.backward()
@@ -200,24 +180,24 @@ class MakeupModel(BaseModel):
def backward_G(self):
"""Calculate the loss for generators G_A and G_B"""
lambda_A = self.lambda_a
lambda_B = self.lambda_b
lambda_vgg = 5e-3
# Identity loss
if self.idt_criterion:
self.idt_A, _ = self.nets['netG'](self.real_A, self.real_A,
self.P_A, self.P_A,
self.c_m_idt_a, self.mask_A_aug,
self.mask_B_aug) # G_A(A)
self.loss_idt_A = self.idt_criterion(self.idt_A,
self.real_A) * lambda_A
self.idt_B, _ = self.nets['netG'](self.real_B, self.real_B,
self.P_B, self.P_B,
self.c_m_idt_b, self.mask_A_aug,
self.mask_B_aug) # G_A(A)
self.loss_idt_B = self.idt_criterion(self.idt_B,
self.real_B) * lambda_B
# visual
self.visual_items['idt_A'] = self.idt_A
@@ -227,17 +207,17 @@ class MakeupModel(BaseModel):
self.loss_idt_B = 0
# GAN loss D_A(G_A(A))
self.loss_G_A = self.gan_criterion(self.nets['netD_A'](self.fake_A),
True)
# GAN loss D_B(G_B(B))
self.loss_G_B = self.gan_criterion(self.nets['netD_B'](self.fake_B),
True)
# Forward cycle loss || G_B(G_A(A)) - A||
self.loss_cycle_A = self.cycle_criterion(self.rec_A,
self.real_A) * lambda_A
# Backward cycle loss || G_A(G_B(B)) - B||
self.loss_cycle_B = self.cycle_criterion(self.rec_B,
self.real_B) * lambda_B
self.losses['G_A_adv_loss'] = self.loss_G_A
self.losses['G_B_adv_loss'] = self.loss_G_B
@@ -270,8 +250,10 @@ class MakeupModel(BaseModel):
fake_match_lip_B = fake_match_lip_B.unsqueeze(0)
fake_A_lip_masked = fake_A * mask_A_lip
fake_B_lip_masked = fake_B * mask_B_lip
g_A_lip_loss_his = self.l1_criterion(fake_A_lip_masked,
fake_match_lip_A)
g_B_lip_loss_his = self.l1_criterion(fake_B_lip_masked,
fake_match_lip_B)
#skin
mask_A_skin = self.mask_A_aug[:, 1].unsqueeze(1)
@@ -294,10 +276,10 @@ class MakeupModel(BaseModel):
fake_match_skin_B = fake_match_skin_B.unsqueeze(0)
fake_A_skin_masked = fake_A * mask_A_skin
fake_B_skin_masked = fake_B * mask_B_skin
g_A_skin_loss_his = self.l1_criterion(fake_A_skin_masked,
fake_match_skin_A)
g_B_skin_loss_his = self.l1_criterion(fake_B_skin_masked,
fake_match_skin_B)
#eye
mask_A_eye = self.mask_A_aug[:, 2].unsqueeze(1)
@@ -320,8 +302,10 @@ class MakeupModel(BaseModel):
fake_match_eye_B = fake_match_eye_B.unsqueeze(0)
fake_A_eye_masked = fake_A * mask_A_eye
fake_B_eye_masked = fake_B * mask_B_eye
g_A_eye_loss_his = self.l1_criterion(fake_A_eye_masked,
fake_match_eye_A)
g_B_eye_loss_his = self.l1_criterion(fake_B_eye_masked,
fake_match_eye_B)
self.loss_G_A_his = (g_A_eye_loss_his + g_A_lip_loss_his +
g_A_skin_loss_his * 0.1) * 0.1
@@ -335,14 +319,14 @@ class MakeupModel(BaseModel):
vgg_s = self.vgg(self.real_A)
vgg_s.stop_gradient = True
vgg_fake_A = self.vgg(self.fake_A)
self.loss_A_vgg = self.l2_criterion(vgg_fake_A,
vgg_s) * lambda_A * lambda_vgg
vgg_r = self.vgg(self.real_B)
vgg_r.stop_gradient = True
vgg_fake_B = self.vgg(self.fake_B)
self.loss_B_vgg = self.l2_criterion(vgg_fake_B,
vgg_r) * lambda_B * lambda_vgg
self.loss_rec = (self.loss_cycle_A * 0.2 + self.loss_cycle_B * 0.2 +
self.loss_A_vgg + self.loss_B_vgg) * 0.5
@@ -359,15 +343,14 @@ class MakeupModel(BaseModel):
(self.mask_A == 10), dtype='float32') + paddle.cast(
(self.mask_A == 8), dtype='float32')
mask_A_consis = paddle.unsqueeze(paddle.clip(mask_A_consis, 0, 1), 1)
self.loss_G_bg_consis = self.l1_criterion(
self.real_A * mask_A_consis, self.fake_A * mask_A_consis) * 0.1
# combined loss and calculate gradients
self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_rec + self.loss_idt + self.loss_G_A_his + self.loss_G_B_his + self.loss_G_bg_consis
self.loss_G.backward()
def train_iter(self, optimizers=None):
"""Calculate losses, gradients, and update network weights; called in every training iteration"""
# forward
self.forward() # compute fake images and reconstruction images.
@@ -392,5 +375,4 @@ class MakeupModel(BaseModel):
self.backward_D_B() # calculate graidents for D_B
self.optimizers['optimizer_DB'].minimize(
self.loss_D_B) #step() # update D_A and D_B's weights
self.optimizers['optimizer_DB'].clear_gradients()
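The refactored constructor above builds every loss from a config dict through build_criterion. A minimal sketch of that registry pattern, with an illustrative Registry class and L1Loss stand-in rather than the project's real implementations:

import copy


class Registry:
    def __init__(self):
        self._map = {}

    def register(self, cls):
        self._map[cls.__name__] = cls
        return cls

    def build(self, cfg):
        cfg = copy.deepcopy(cfg)  # pop 'name', pass the rest as kwargs
        return self._map[cfg.pop('name')](**cfg)


CRITERIONS = Registry()


@CRITERIONS.register
class L1Loss:
    def __init__(self, loss_weight=1.0):
        self.loss_weight = loss_weight


idt_criterion = CRITERIONS.build({'name': 'L1Loss', 'loss_weight': 0.5})
print(idt_criterion.loss_weight)  # 0.5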
@@ -19,8 +19,6 @@ import paddle
from ..utils.logger import get_logger
def _calculate_fan_in_and_fan_out(tensor):
dimensions = len(tensor.shape)
@@ -313,5 +311,6 @@ def init_weights(net, init_type='normal', init_gain=0.02):
normal_(m.weight, 1.0, init_gain)
constant_(m.bias, 0.0)
logger = get_logger()
logger.debug('initialize network with %s' % init_type)
net.apply(init_func) # apply the initialization function <init_func>
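A usage sketch for init_weights, mirroring the xavier call made by MakeupModel above; the small network is only illustrative:

import paddle.nn as nn

net = nn.Sequential(nn.Conv2D(3, 16, 3, padding=1), nn.BatchNorm2D(16))
init_weights(net, init_type='xavier', init_gain=1.0)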
@@ -12,4 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .lr_scheduler import CosineAnnealingRestartLR, LinearDecay
from .optimizer import *
from .builder import build_lr_scheduler
from .builder import build_optimizer
@@ -15,14 +15,9 @@
import copy
import paddle
from .builder import OPTIMIZERS
OPTIMIZERS.register(paddle.optimizer.Adam)
OPTIMIZERS.register(paddle.optimizer.SGD)
OPTIMIZERS.register(paddle.optimizer.Momentum)
OPTIMIZERS.register(paddle.optimizer.RMSProp)
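With the optimizers registered, a config-driven builder can construct one by name, much as the pre-refactor build_optimizer did. A hedged sketch (the real builder lives in .builder and may differ):

import copy

import paddle


def build_optimizer_sketch(cfg, lr_scheduler, parameters):
    cfg = copy.deepcopy(cfg)
    name = cfg.pop('name')  # e.g. 'Adam', as registered above
    cls = getattr(paddle.optimizer, name)
    # remaining cfg keys (e.g. beta1) become constructor kwargs
    return cls(learning_rate=lr_scheduler, parameters=parameters, **cfg)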