未验证 提交 b2b881ed 编写于 作者: H Hecong Wu 提交者: GitHub

add some custom gan models and datasets (#95)

* add some custom gan models
上级 fdbc6aee
epochs: 200
output_dir: output_dir
model:
name: GANModel
generator:
name: ConditionalDeepConvGenerator
latent_dim: 128
output_nc: 1
size: 28
ngf: 64
n_class: 10
discriminator:
name: NLayerDiscriminatorWithClassification
ndf: 16
n_layers: 3
input_nc: 1
norm_type: batch
n_class: 10
use_sigmoid: True
gan_mode: vanilla
dataset:
train:
name: CommonVisionDataset
class_name: MNIST
dataroot: None
num_workers: 4
batch_size: 64
mode: train
return_cls: True
transforms:
- name: Normalize
mean: [127.5]
std: [127.5]
keys: [image]
test:
name: CommonVisionDataset
class_name: MNIST
dataroot: None
num_workers: 0
batch_size: 64
mode: test
transforms:
- name: Normalize
mean: [127.5]
std: [127.5]
keys: [image]
return_cls: True
optimizer:
name: Adam
beta1: 0.5
lr_scheduler:
name: linear
learning_rate: 0.0002
start_epoch: 100
decay_epochs: 100
log_config:
interval: 100
visiual_interval: 500
snapshot_config:
interval: 5
epochs: 200
output_dir: output_dir
model:
name: GANModel
generator:
name: DeepConvGenerator
latent_dim: 128
output_nc: 1
size: 28
ngf: 64
discriminator:
name: NLayerDiscriminator
ndf: 16
n_layers: 3
input_nc: 1
norm_type: instance
gan_mode: wgan
n_critic: 5
dataset:
train:
name: CommonVisionDataset
class_name: MNIST
dataroot: None
num_workers: 4
batch_size: 64
mode: train
return_cls: False
transforms:
- name: Normalize
mean: [127.5]
std: [127.5]
keys: [image]
test:
name: CommonVisionDataset
class_name: MNIST
dataroot: None
num_workers: 0
batch_size: 64
mode: test
transforms:
- name: Normalize
mean: [127.5]
std: [127.5]
keys: [image]
return_cls: False
optimizer:
name: Adam
beta1: 0.5
lr_scheduler:
name: linear
learning_rate: 0.0002
start_epoch: 100
decay_epochs: 100
log_config:
interval: 100
visiual_interval: 500
snapshot_config:
interval: 5
......@@ -17,4 +17,5 @@ from .single_dataset import SingleDataset
from .paired_dataset import PairedDataset
from .sr_image_dataset import SRImageDataset
from .makeup_dataset import MakeupDataset
from .common_vision_dataset import CommonVisionDataset
from .animeganv2_dataset import AnimeGANV2Dataset
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
from .builder import DATASETS
from .base_dataset import BaseDataset
from .transforms.builder import build_transforms
@DATASETS.register()
class CommonVisionDataset(BaseDataset):
"""
Dataset for using paddle vision default datasets
"""
def __init__(self, cfg):
"""Initialize this dataset class.
Args:
cfg (dict) -- stores all the experiment flags
"""
super(CommonVisionDataset, self).__init__(cfg)
dataset_cls = getattr(paddle.vision.datasets, cfg.pop('class_name'))
transform = build_transforms(cfg.pop('transforms', None))
self.return_cls = cfg.pop('return_cls', True)
param_dict = {}
param_names = list(dataset_cls.__init__.__code__.co_varnames)
if 'transform' in param_names:
param_dict['transform'] = transform
for name in param_names:
if name in cfg:
param_dict[name] = cfg.get(name)
self.dataset = dataset_cls(**param_dict)
def __getitem__(self, index):
return_dict = {}
return_list = self.dataset[index]
if isinstance(return_list, (tuple, list)):
if len(return_list) == 2:
return_dict['img'] = return_list[0]
if self.return_cls:
return_dict['class_id'] = np.asarray(return_list[1])
else:
return_dict['img'] = return_list[0]
else:
return_dict['img'] = return_list
return return_dict
def __len__(self):
return len(self.dataset)
......@@ -57,7 +57,7 @@ class SingleDataset(BaseDataset):
return len(self.A_paths)
def get_path_by_indexs(self, indexs):
if isinstance(indexs, paddle.Variable):
if isinstance(indexs, paddle.Tensor):
indexs = indexs.numpy()
current_paths = []
for index in indexs:
......
......@@ -13,6 +13,7 @@
# limitations under the License.
from .base_model import BaseModel
from .gan_model import GANModel
from .cycle_gan_model import CycleGANModel
from .pix2pix_model import Pix2PixModel
from .srgan_model import SRGANModel
......
......@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .nlayers import NLayerDiscriminator
from .nlayers import NLayerDiscriminator, NLayerDiscriminatorWithClassification
from .discriminator_ugatit import UGATITDiscriminator
from .dcdiscriminator import DCDiscriminator
from .discriminator_animegan import AnimeDiscriminator
......@@ -27,7 +27,7 @@ from .builder import DISCRIMINATORS
@DISCRIMINATORS.register()
class NLayerDiscriminator(nn.Layer):
"""Defines a PatchGAN discriminator"""
def __init__(self, input_nc, ndf=64, n_layers=3, norm_type='instance'):
def __init__(self, input_nc, ndf=64, n_layers=3, norm_type='instance', use_sigmoid=False):
"""Construct a PatchGAN discriminator
Parameters:
......@@ -35,6 +35,7 @@ class NLayerDiscriminator(nn.Layer):
ndf (int) -- the number of filters in the last conv layer
n_layers (int) -- the number of conv layers in the discriminator
norm_type (str) -- normalization layer type
use_sigmoid (bool) -- whether use sigmoid at last
"""
super(NLayerDiscriminator, self).__init__()
norm_layer = build_norm_layer(norm_type)
......@@ -139,7 +140,28 @@ class NLayerDiscriminator(nn.Layer):
] # output 1 channel prediction map
self.model = nn.Sequential(*sequence)
self.final_act = F.sigmoid if use_sigmoid else (lambda x:x)
def forward(self, input):
"""Standard forward."""
return self.model(input)
return self.final_act(self.model(input))
@DISCRIMINATORS.register()
class NLayerDiscriminatorWithClassification(NLayerDiscriminator):
def __init__(self, input_nc, n_class=10, **kwargs):
input_nc = input_nc + n_class
super(NLayerDiscriminatorWithClassification, self).__init__(input_nc, **kwargs)
self.n_class = n_class
def forward(self, x, class_id):
if self.n_class > 0:
class_id = (class_id % self.n_class).detach()
class_id = F.one_hot(class_id, self.n_class).astype('float32')
class_id = class_id.reshape([x.shape[0], -1, 1, 1])
class_id = class_id.tile([1,1,*x.shape[2:]])
x = paddle.concat([x, class_id], 1)
return super(NLayerDiscriminatorWithClassification, self).forward(x)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import numpy as np
from .base_model import BaseModel
from .builder import MODELS
from .generators.builder import build_generator
from .discriminators.builder import build_discriminator
from .losses import GANLoss
from ..solver import build_optimizer
from ..modules.init import init_weights
from ..utils.visual import make_grid
@MODELS.register()
class GANModel(BaseModel):
""" This class implements the vanilla GAN model with some tricks.
vanilla GAN paper: https://arxiv.org/abs/1406.2661
"""
def __init__(self, cfg):
"""Initialize the GAN Model class.
Parameters:
cfg (config dict)-- stores all the experiment flags; needs to be a subclass of Dict
"""
super(GANModel, self).__init__(cfg)
self.step = 0
self.n_critic = cfg.model.get('n_critic', 1)
self.visual_interval = cfg.log_config.visiual_interval
self.samples_every_row = cfg.model.get('samples_every_row', 8)
# define networks (both generator and discriminator)
self.nets['netG'] = build_generator(cfg.model.generator)
init_weights(self.nets['netG'])
# define a discriminator
if self.is_train:
self.nets['netD'] = build_discriminator(cfg.model.discriminator)
init_weights(self.nets['netD'])
if self.is_train:
self.losses = {}
# define loss functions
self.criterionGAN = GANLoss(cfg.model.gan_mode)
# build optimizers
self.build_lr_scheduler()
self.optimizers['optimizer_G'] = build_optimizer(
cfg.optimizer,
self.lr_scheduler,
parameter_list=self.nets['netG'].parameters())
self.optimizers['optimizer_D'] = build_optimizer(
cfg.optimizer,
self.lr_scheduler,
parameter_list=self.nets['netD'].parameters())
def set_input(self, input):
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
Parameters:
input (list): include the data itself and its metadata information.
"""
if isinstance(input, (list, tuple)):
input = input[0]
if not isinstance(input, dict):
input = {'img': input}
self.D_real_inputs = [paddle.to_tensor(input['img'])]
if 'class_id' in input: # n class input
self.n_class = self.nets['netG'].n_class
self.D_real_inputs += [paddle.to_tensor(input['class_id'], dtype='int64')]
else:
self.n_class = 0
batch_size = self.D_real_inputs[0].shape[0]
self.G_inputs = self.nets['netG'].random_inputs(batch_size)
if not isinstance(self.G_inputs, (list, tuple)):
self.G_inputs = [self.G_inputs]
if not hasattr(self, 'G_fixed_inputs'):
self.G_fixed_inputs = [t for t in self.G_inputs]
if self.n_class > 0:
rows_num = (batch_size - 1) // self.samples_every_row + 1
class_ids = paddle.randint(0, self.n_class, [rows_num, 1])
class_ids = class_ids.tile([1, self.samples_every_row])
class_ids = class_ids.reshape([-1,])[:batch_size].detach()
self.G_fixed_inputs[1] = class_ids.detach()
def forward(self):
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
self.fake_imgs = self.nets['netG'](*self.G_inputs) # G(img, class_id)
# put items to visual dict
self.visual_items['fake_imgs'] = make_grid(self.fake_imgs, self.samples_every_row).detach()
def backward_D(self):
"""Calculate GAN loss for the discriminator"""
# Fake; stop backprop to the generator by detaching fake_imgs
# use conditional GANs; we need to feed both input and output to the discriminator
self.loss_D = 0
self.D_fake_inputs = [self.fake_imgs.detach()]
if len(self.G_inputs) > 1 and self.G_inputs[1] is not None:
self.D_fake_inputs += [self.G_inputs[1]]
pred_fake = self.nets['netD'](*self.D_fake_inputs)
# Real
real_imgs = self.D_real_inputs[0]
self.visual_items['real_imgs'] = make_grid(real_imgs, self.samples_every_row).detach()
pred_real = self.nets['netD'](*self.D_real_inputs)
self.loss_D_fake = self.criterionGAN(pred_fake, False, True)
self.loss_D_real = self.criterionGAN(pred_real, True, True)
# combine loss and calculate gradients
if self.cfg.model.gan_mode in ['vanilla', 'lsgan']:
self.loss_D = self.loss_D + (self.loss_D_fake + self.loss_D_real) * 0.5
else:
self.loss_D = self.loss_D + self.loss_D_fake + self.loss_D_real
self.loss_D.backward()
self.losses['D_fake_loss'] = self.loss_D_fake
self.losses['D_real_loss'] = self.loss_D_real
def backward_G(self):
"""Calculate GAN loss for the generator"""
# First, G(imgs) should fake the discriminator
self.D_fake_inputs = [self.fake_imgs]
if len(self.G_inputs) > 1 and self.G_inputs[1] is not None:
self.D_fake_inputs += [self.G_inputs[1]]
pred_fake = self.nets['netD'](*self.D_fake_inputs)
self.loss_G_GAN = self.criterionGAN(pred_fake, True, False)
# combine loss and calculate gradients
self.loss_G = self.loss_G_GAN
self.loss_G.backward()
self.losses['G_adv_loss'] = self.loss_G_GAN
def optimize_parameters(self):
# compute fake images: G(imgs)
self.forward()
# update D
self.set_requires_grad(self.nets['netD'], True)
self.optimizers['optimizer_D'].clear_grad()
self.backward_D()
self.optimizers['optimizer_D'].step()
self.set_requires_grad(self.nets['netD'], False)
# weight clip
if self.cfg.model.gan_mode == 'wgan':
with paddle.no_grad():
for p in self.nets['netD'].parameters():
p[:] = p.clip(-0.01, 0.01)
if self.step % self.n_critic == 0:
# update G
self.optimizers['optimizer_G'].clear_grad()
self.backward_G()
self.optimizers['optimizer_G'].step()
if self.step % self.visual_interval == 0:
with paddle.no_grad():
self.visual_items['fixed_generated_imgs'] = make_grid(
self.nets['netG'](*self.G_fixed_inputs), self.samples_every_row
)
self.step += 1
......@@ -16,6 +16,7 @@ from .resnet import ResnetGenerator
from .unet import UnetGenerator
from .rrdb_net import RRDBNet
from .makeup import GeneratorPSGANAttention
from .deep_conv import DeepConvGenerator, ConditionalDeepConvGenerator
from .resnet_ugatit import ResnetUGATITGenerator
from .dcgenerator import DCGenerator
from .generater_animegan import AnimeGenerator, AnimeGeneratorLite
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from .builder import GENERATORS
@GENERATORS.register()
class DeepConvGenerator(nn.Layer):
"""Create a Deep Convolutional generator"""
def __init__(self, latent_dim, output_nc, size=64, ngf=64):
"""Construct a Deep Convolutional generator
Args:
latent_dim (int) -- the number of latent dimension
output_nc (int) -- the number of channels in output images
size (int) -- size of output tensor
ngf (int) -- the number of filters in the last conv layer
Refer to https://arxiv.org/abs/1511.06434
"""
super(DeepConvGenerator, self).__init__()
self.latent_dim = latent_dim
self.ngf = ngf
self.init_size = size // 4
self.l1 = nn.Sequential(nn.Linear(latent_dim, ngf*2 * self.init_size ** 2))
self.conv_blocks = nn.Sequential(
nn.BatchNorm2D(ngf*2),
nn.Upsample(scale_factor=2),
nn.Conv2D(ngf*2, ngf*2, 3, stride=1, padding=1),
nn.BatchNorm2D(ngf*2, 0.2),
nn.LeakyReLU(0.2),
nn.Upsample(scale_factor=2),
nn.Conv2D(ngf*2, ngf, 3, stride=1, padding=1),
nn.BatchNorm2D(ngf, 0.2),
nn.LeakyReLU(0.2),
nn.Conv2D(ngf, output_nc, 3, stride=1, padding=1),
nn.Tanh(),
)
def random_inputs(self, batch_size):
return paddle.randn([batch_size, self.latent_dim])
def forward(self, z):
out = self.l1(z)
out = out.reshape([out.shape[0], self.ngf * 2, self.init_size, self.init_size])
img = self.conv_blocks(out)
return img
@GENERATORS.register()
class ConditionalDeepConvGenerator(DeepConvGenerator):
def __init__(self, latent_dim, output_nc, n_class=10, **kwargs):
super(ConditionalDeepConvGenerator, self).__init__(latent_dim + n_class, output_nc, **kwargs)
self.n_class = n_class
self.latent_dim = latent_dim
def random_inputs(self, batch_size):
return_list = [super(ConditionalDeepConvGenerator, self).random_inputs(batch_size)]
class_id = paddle.randint(0, self.n_class, [batch_size])
return return_list + [class_id]
def forward(self, x, class_id=None):
if self.n_class > 0:
class_id = (class_id % self.n_class).detach()
class_id = F.one_hot(class_id, self.n_class).astype('float32')
class_id = class_id.reshape([x.shape[0], -1])
x = paddle.concat([x, class_id], 1)
return super(ConditionalDeepConvGenerator, self).forward(x)
......@@ -16,6 +16,7 @@ import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
class GANLoss(nn.Layer):
......@@ -44,7 +45,7 @@ class GANLoss(nn.Layer):
self.loss = nn.MSELoss()
elif gan_mode == 'vanilla':
self.loss = nn.BCEWithLogitsLoss()
elif gan_mode in ['wgangp']:
elif gan_mode in ['wgan', 'wgangp', 'hinge', 'logistic']:
self.loss = None
else:
raise NotImplementedError('gan mode %s not implemented' % gan_mode)
......@@ -77,12 +78,13 @@ class GANLoss(nn.Layer):
# target_tensor.stop_gradient = True
return target_tensor
def __call__(self, prediction, target_is_real):
def __call__(self, prediction, target_is_real, is_updating_D=None):
"""Calculate loss given Discriminator's output and grount truth labels.
Parameters:
prediction (tensor) - - tpyically the prediction output from a discriminator
target_is_real (bool) - - if the ground truth label is for real images or fake images
is_updating_D (bool) - - if we are in updating D step or not
Returns:
the calculated loss.
......@@ -90,9 +92,20 @@ class GANLoss(nn.Layer):
if self.gan_mode in ['lsgan', 'vanilla']:
target_tensor = self.get_target_tensor(prediction, target_is_real)
loss = self.loss(prediction, target_tensor)
elif self.gan_mode == 'wgangp':
elif self.gan_mode.find('wgan') != -1:
if target_is_real:
loss = -prediction.mean()
else:
loss = prediction.mean()
elif self.gan_mode == 'hinge':
if target_is_real:
loss = F.relu(1 - prediction) if is_updating_D else -prediction
else:
loss = F.relu(1 + prediction) if is_updating_D else prediction
loss = loss.mean()
elif self.gan_mode == 'logistic':
if target_is_real:
loss = F.softplus(-prediction).mean()
else:
loss = F.softplus(prediction).mean()
return loss
......@@ -12,9 +12,86 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import paddle
import numpy as np
from PIL import Image
irange = range
def make_grid(tensor, nrow=8, normalize=False, range=None, scale_each=False):
"""Make a grid of images.
Args:
tensor (Tensor or list): 4D mini-batch Tensor of shape (B x C x H x W)
or a list of images all of the same size.
nrow (int, optional): Number of images displayed in each row of the grid.
The final grid size is ``(B / nrow, nrow)``. Default: ``8``.
normalize (bool, optional): If True, shift the image to the range (0, 1),
by the min and max values specified by :attr:`range`. Default: ``False``.
range (tuple, optional): tuple (min, max) where min and max are numbers,
then these numbers are used to normalize the image. By default, min and max
are computed from the tensor.
scale_each (bool, optional): If ``True``, scale each image in the batch of
images separately rather than the (min, max) over all images. Default: ``False``.
"""
if not (isinstance(tensor, paddle.Tensor) or
(isinstance(tensor, list) and all(isinstance(tensor, t) for t in tensor))):
raise TypeError('tensor or list of tensors expected, got {}'.format(type(tensor)))
# if list of tensors, convert to a 4D mini-batch Tensor
if isinstance(tensor, list):
tensor = paddle.stack(tensor, 0)
if tensor.dim() == 2: # single image H x W
tensor = tensor.unsqueeze(0)
if tensor.dim() == 3: # single image
if tensor.shape[0] == 1: # if single-channel, convert to 3-channel
tensor = paddle.concat([tensor, tensor, tensor], 0)
tensor = tensor.unsqueeze(0)
if tensor.dim() == 4 and tensor.shape[1] == 1: # single-channel images
tensor = paddle.concat([tensor, tensor, tensor], 1)
if normalize is True:
tensor = tensor.astype(tensor.dtype) # avoid modifying tensor in-place
if range is not None:
assert isinstance(range, tuple), \
"range has to be a tuple (min, max) if specified. min and max are numbers"
def norm_ip(img, min, max):
img[:] = img.clip(min=min, max=max)
img[:] = (img - min) / (max - min + 1e-5)
def norm_range(t, range):
if range is not None:
norm_ip(t, range[0], range[1])
else:
norm_ip(t, float(t.min()), float(t.max()))
if scale_each is True:
for t in tensor: # loop over mini-batch dimension
norm_range(t, range)
else:
norm_range(tensor, range)
if tensor.shape[0] == 1:
return tensor.squeeze(0)
# make the mini-batch of images into a grid
nmaps = tensor.shape[0]
xmaps = min(nrow, nmaps)
ymaps = int(math.ceil(float(nmaps) / xmaps))
height, width = int(tensor.shape[2]), int(tensor.shape[3])
num_channels = tensor.shape[1]
canvas = paddle.zeros((num_channels, height * ymaps, width * xmaps), dtype=tensor.dtype)
k = 0
for y in irange(ymaps):
for x in irange(xmaps):
if k >= nmaps:
break
canvas[:, y * height:(y + 1) * height, x * width:(x + 1) * width] = tensor[k]
k = k + 1
return canvas
def tensor2img(input_image, min_max=(-1., 1.), imtype=np.uint8):
""""Converts a Tensor array into a numpy image array.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册