未验证 提交 f7b53f07 编写于 作者: L LielinJiang 提交者: GitHub

adapt wgan (#128)

上级 7bba9f8d
...@@ -15,18 +15,20 @@ model: ...@@ -15,18 +15,20 @@ model:
n_layers: 3 n_layers: 3
input_nc: 1 input_nc: 1
norm_type: instance norm_type: instance
gan_mode: wgan gan_criterion:
n_critic: 5 name: GANLoss
gan_mode: wgan
params:
disc_iters: 5
visual_interval: 500
dataset: dataset:
train: train:
name: CommonVisionDataset name: CommonVisionDataset
class_name: MNIST dataset_name: MNIST
dataroot: None
num_workers: 4 num_workers: 4
batch_size: 64 batch_size: 64
mode: train return_label: False
return_cls: False
transforms: transforms:
- name: Normalize - name: Normalize
mean: [127.5] mean: [127.5]
...@@ -34,28 +36,36 @@ dataset: ...@@ -34,28 +36,36 @@ dataset:
keys: [image] keys: [image]
test: test:
name: CommonVisionDataset name: CommonVisionDataset
class_name: MNIST dataset_name: MNIST
dataroot: None
num_workers: 0 num_workers: 0
batch_size: 64 batch_size: 64
mode: test return_label: False
transforms: transforms:
- name: Normalize - name: Normalize
mean: [127.5] mean: [127.5]
std: [127.5] std: [127.5]
keys: [image] keys: [image]
return_cls: False
optimizer:
name: Adam
beta1: 0.5
lr_scheduler: lr_scheduler:
name: linear name: LinearDecay
learning_rate: 0.0002 learning_rate: 0.0002
start_epoch: 100 start_epoch: 100
decay_epochs: 100 decay_epochs: 100
# will get from real dataset
iters_per_epoch: 1
optimizer:
optimizer_G:
name: Adam
net_names:
- netG
beta1: 0.5
optimizer_D:
name: Adam
net_names:
- netD
beta1: 0.5
log_config: log_config:
interval: 100 interval: 100
......
...@@ -21,29 +21,38 @@ from .transforms.builder import build_transforms ...@@ -21,29 +21,38 @@ from .transforms.builder import build_transforms
@DATASETS.register() @DATASETS.register()
class CommonVisionDataset(BaseDataset): class CommonVisionDataset(paddle.io.Dataset):
""" """
Dataset for using paddle vision default datasets Dataset for using paddle vision default datasets, such as mnist, flowers.
""" """
def __init__(self, cfg): def __init__(self,
dataset_name,
transforms=None,
return_label=True,
params=None):
"""Initialize this dataset class. """Initialize this dataset class.
Args: Args:
cfg (dict) -- stores all the experiment flags dataset_name (str): return a dataset from paddle.vision.datasets by this option.
transforms (list[dict]): A sequence of data transforms config.
return_label (bool): whether to retuan a label of a sample.
params (dict): paramters of paddle.vision.datasets.
""" """
super(CommonVisionDataset, self).__init__(cfg) super(CommonVisionDataset, self).__init__()
dataset_cls = getattr(paddle.vision.datasets, cfg.pop('class_name')) dataset_cls = getattr(paddle.vision.datasets, dataset_name)
transform = build_transforms(cfg.pop('transforms', None)) transform = build_transforms(transforms)
self.return_cls = cfg.pop('return_cls', True) self.return_label = return_label
param_dict = {} param_dict = {}
param_names = list(dataset_cls.__init__.__code__.co_varnames) param_names = list(dataset_cls.__init__.__code__.co_varnames)
if 'transform' in param_names: if 'transform' in param_names:
param_dict['transform'] = transform param_dict['transform'] = transform
for name in param_names:
if name in cfg: if params is not None:
param_dict[name] = cfg.get(name) for name in param_names:
if name in params:
param_dict[name] = params[name]
self.dataset = dataset_cls(**param_dict) self.dataset = dataset_cls(**param_dict)
...@@ -53,7 +62,7 @@ class CommonVisionDataset(BaseDataset): ...@@ -53,7 +62,7 @@ class CommonVisionDataset(BaseDataset):
if isinstance(return_list, (tuple, list)): if isinstance(return_list, (tuple, list)):
if len(return_list) == 2: if len(return_list) == 2:
return_dict['img'] = return_list[0] return_dict['img'] = return_list[0]
if self.return_cls: if self.return_label:
return_dict['class_id'] = np.asarray(return_list[1]) return_dict['class_id'] = np.asarray(return_list[1])
else: else:
return_dict['img'] = return_list[0] return_dict['img'] = return_list[0]
......
...@@ -211,12 +211,24 @@ class Trainer: ...@@ -211,12 +211,24 @@ class Trainer:
current_paths = self.model.get_image_paths() current_paths = self.model.get_image_paths()
current_visuals = self.model.get_current_visuals() current_visuals = self.model.get_current_visuals()
for j in range(len(current_paths)): if len(current_visuals) > 0 and list(
short_path = os.path.basename(current_paths[j]) current_visuals.values())[0].shape == 4:
basename = os.path.splitext(short_path)[0] num_samples = list(current_visuals.values())[0].shape[0]
else:
num_samples = 1
for j in range(num_samples):
if j < len(current_paths):
short_path = os.path.basename(current_paths[j])
basename = os.path.splitext(short_path)[0]
else:
basename = '{:04d}_{:04d}'.format(i, j)
for k, img_tensor in current_visuals.items(): for k, img_tensor in current_visuals.items():
name = '%s_%s' % (basename, k) name = '%s_%s' % (basename, k)
visual_results.update({name: img_tensor[j]}) if len(img_tensor.shape) == 4:
visual_results.update({name: img_tensor[j]})
else:
visual_results.update({name: img_tensor})
self.visual('visual_test', self.visual('visual_test',
visual_results=visual_results, visual_results=visual_results,
......
...@@ -50,7 +50,7 @@ class BaseModel(ABC): ...@@ -50,7 +50,7 @@ class BaseModel(ABC):
# save checkpoint (model.nets) \/ # save checkpoint (model.nets) \/
""" """
def __init__(self): def __init__(self, params=None):
"""Initialize the BaseModel class. """Initialize the BaseModel class.
When creating your custom class, you need to implement your own initialization. When creating your custom class, you need to implement your own initialization.
...@@ -62,7 +62,13 @@ class BaseModel(ABC): ...@@ -62,7 +62,13 @@ class BaseModel(ABC):
-- self.optimizers (dict): define and initialize optimizers. You can define one optimizer for each network. -- self.optimizers (dict): define and initialize optimizers. You can define one optimizer for each network.
If two networks are updated at the same time, you can use itertools.chain to group them. If two networks are updated at the same time, you can use itertools.chain to group them.
See cycle_gan_model.py for an example. See cycle_gan_model.py for an example.
Args:
params (dict): Hyper params for train or test. Default: None.
""" """
self.params = params
self.is_train = True if self.params is None else self.params.get(
'is_train', True)
self.nets = OrderedDict() self.nets = OrderedDict()
self.optimizers = OrderedDict() self.optimizers = OrderedDict()
...@@ -149,7 +155,9 @@ class BaseModel(ABC): ...@@ -149,7 +155,9 @@ class BaseModel(ABC):
def get_image_paths(self): def get_image_paths(self):
""" Return image paths that are used to load current data""" """ Return image paths that are used to load current data"""
return self.image_paths if hasattr(self, 'image_paths'):
return self.image_paths
return []
def get_current_visuals(self): def get_current_visuals(self):
"""Return visualization images.""" """Return visualization images."""
......
...@@ -19,7 +19,7 @@ from .base_model import BaseModel ...@@ -19,7 +19,7 @@ from .base_model import BaseModel
from .builder import MODELS from .builder import MODELS
from .generators.builder import build_generator from .generators.builder import build_generator
from .discriminators.builder import build_discriminator from .discriminators.builder import build_discriminator
from .criterions.gan_loss import GANLoss from .criterions.builder import build_criterion
from ..solver import build_optimizer from ..solver import build_optimizer
from ..modules.init import init_weights from ..modules.init import init_weights
...@@ -32,44 +32,46 @@ class GANModel(BaseModel): ...@@ -32,44 +32,46 @@ class GANModel(BaseModel):
vanilla GAN paper: https://arxiv.org/abs/1406.2661 vanilla GAN paper: https://arxiv.org/abs/1406.2661
""" """
def __init__(self, cfg): def __init__(self,
generator,
discriminator=None,
gan_criterion=None,
params=None):
"""Initialize the GAN Model class. """Initialize the GAN Model class.
Parameters: Args:
cfg (config dict)-- stores all the experiment flags; needs to be a subclass of Dict generator (dict): config of generator.
discriminator (dict): config of discriminator.
gan_criterion (dict): config of gan criterion.
params (dict): hyper params for train or test. Default: None.
""" """
super(GANModel, self).__init__(cfg) super(GANModel, self).__init__(params)
self.step = 0 self.iter = 0
self.n_critic = cfg.model.get('n_critic', 1)
self.visual_interval = cfg.log_config.visiual_interval self.disc_iters = 1 if self.params is None else self.params.get(
self.samples_every_row = cfg.model.get('samples_every_row', 8) 'disc_iters', 1)
self.disc_start_iters = (0 if self.params is None else self.params.get(
# define networks (both generator and discriminator) 'disc_start_iters', 0))
self.nets['netG'] = build_generator(cfg.model.generator) self.samples_every_row = (8 if self.params is None else self.params.get(
'samples_every_row', 8))
self.visual_interval = (500 if self.params is None else self.params.get(
'visual_interval', 500))
# define generator
self.nets['netG'] = build_generator(generator)
init_weights(self.nets['netG']) init_weights(self.nets['netG'])
# define a discriminator # define a discriminator
if self.is_train: if self.is_train:
self.nets['netD'] = build_discriminator(cfg.model.discriminator) if discriminator is not None:
init_weights(self.nets['netD']) self.nets['netD'] = build_discriminator(discriminator)
init_weights(self.nets['netD'])
if self.is_train:
self.losses = {}
# define loss functions # define loss functions
self.criterionGAN = GANLoss(cfg.model.gan_mode) if gan_criterion:
self.criterionGAN = build_criterion(gan_criterion)
# build optimizers
self.build_lr_scheduler() def setup_input(self, input):
self.optimizers['optimizer_G'] = build_optimizer(
cfg.optimizer,
self.lr_scheduler,
parameter_list=self.nets['netG'].parameters())
self.optimizers['optimizer_D'] = build_optimizer(
cfg.optimizer,
self.lr_scheduler,
parameter_list=self.nets['netD'].parameters())
def set_input(self, input):
"""Unpack input data from the dataloader and perform necessary pre-processing steps. """Unpack input data from the dataloader and perform necessary pre-processing steps.
Parameters: Parameters:
...@@ -131,7 +133,7 @@ class GANModel(BaseModel): ...@@ -131,7 +133,7 @@ class GANModel(BaseModel):
self.loss_D_real = self.criterionGAN(pred_real, True, True) self.loss_D_real = self.criterionGAN(pred_real, True, True)
# combine loss and calculate gradients # combine loss and calculate gradients
if self.cfg.model.gan_mode in ['vanilla', 'lsgan']: if self.criterionGAN.gan_mode in ['vanilla', 'lsgan']:
self.loss_D = self.loss_D + (self.loss_D_fake + self.loss_D = self.loss_D + (self.loss_D_fake +
self.loss_D_real) * 0.5 self.loss_D_real) * 0.5
else: else:
...@@ -159,34 +161,34 @@ class GANModel(BaseModel): ...@@ -159,34 +161,34 @@ class GANModel(BaseModel):
self.losses['G_adv_loss'] = self.loss_G_GAN self.losses['G_adv_loss'] = self.loss_G_GAN
def optimize_parameters(self): def train_iter(self, optimizers=None):
# compute fake images: G(imgs) # compute fake images: G(imgs)
self.forward() self.forward()
# update D # update D
self.set_requires_grad(self.nets['netD'], True) self.set_requires_grad(self.nets['netD'], True)
self.optimizers['optimizer_D'].clear_grad() optimizers['optimizer_D'].clear_grad()
self.backward_D() self.backward_D()
self.optimizers['optimizer_D'].step() optimizers['optimizer_D'].step()
self.set_requires_grad(self.nets['netD'], False) self.set_requires_grad(self.nets['netD'], False)
# weight clip # weight clip
if self.cfg.model.gan_mode == 'wgan': if self.criterionGAN.gan_mode == 'wgan':
with paddle.no_grad(): with paddle.no_grad():
for p in self.nets['netD'].parameters(): for p in self.nets['netD'].parameters():
p[:] = p.clip(-0.01, 0.01) p[:] = p.clip(-0.01, 0.01)
if self.step % self.n_critic == 0: if self.iter > self.disc_start_iters and self.iter % self.disc_iters == 0:
# update G # update G
self.optimizers['optimizer_G'].clear_grad() optimizers['optimizer_G'].clear_grad()
self.backward_G() self.backward_G()
self.optimizers['optimizer_G'].step() optimizers['optimizer_G'].step()
if self.step % self.visual_interval == 0: if self.iter % self.visual_interval == 0:
with paddle.no_grad(): with paddle.no_grad():
self.visual_items['fixed_generated_imgs'] = make_grid( self.visual_items['fixed_generated_imgs'] = make_grid(
self.nets['netG'](*self.G_fixed_inputs), self.nets['netG'](*self.G_fixed_inputs),
self.samples_every_row) self.samples_every_row)
self.step += 1 self.iter += 1
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册