未验证 提交 f7b53f07 编写于 作者: L LielinJiang 提交者: GitHub

adapt wgan (#128)

上级 7bba9f8d
......@@ -15,18 +15,20 @@ model:
n_layers: 3
input_nc: 1
norm_type: instance
gan_mode: wgan
n_critic: 5
gan_criterion:
name: GANLoss
gan_mode: wgan
params:
disc_iters: 5
visual_interval: 500
dataset:
train:
name: CommonVisionDataset
class_name: MNIST
dataroot: None
dataset_name: MNIST
num_workers: 4
batch_size: 64
mode: train
return_cls: False
return_label: False
transforms:
- name: Normalize
mean: [127.5]
......@@ -34,28 +36,36 @@ dataset:
keys: [image]
test:
name: CommonVisionDataset
class_name: MNIST
dataroot: None
dataset_name: MNIST
num_workers: 0
batch_size: 64
mode: test
return_label: False
transforms:
- name: Normalize
mean: [127.5]
std: [127.5]
keys: [image]
return_cls: False
optimizer:
name: Adam
beta1: 0.5
lr_scheduler:
name: linear
name: LinearDecay
learning_rate: 0.0002
start_epoch: 100
decay_epochs: 100
# will get from real dataset
iters_per_epoch: 1
optimizer:
optimizer_G:
name: Adam
net_names:
- netG
beta1: 0.5
optimizer_D:
name: Adam
net_names:
- netD
beta1: 0.5
log_config:
interval: 100
......
......@@ -21,29 +21,38 @@ from .transforms.builder import build_transforms
@DATASETS.register()
class CommonVisionDataset(BaseDataset):
class CommonVisionDataset(paddle.io.Dataset):
"""
Dataset for using paddle vision default datasets
Dataset for using paddle vision default datasets, such as mnist, flowers.
"""
def __init__(self, cfg):
def __init__(self,
dataset_name,
transforms=None,
return_label=True,
params=None):
"""Initialize this dataset class.
Args:
cfg (dict) -- stores all the experiment flags
dataset_name (str): return a dataset from paddle.vision.datasets by this option.
transforms (list[dict]): A sequence of data transforms config.
return_label (bool): whether to return a label of a sample.
params (dict): parameters of paddle.vision.datasets.
"""
super(CommonVisionDataset, self).__init__(cfg)
super(CommonVisionDataset, self).__init__()
dataset_cls = getattr(paddle.vision.datasets, cfg.pop('class_name'))
transform = build_transforms(cfg.pop('transforms', None))
self.return_cls = cfg.pop('return_cls', True)
dataset_cls = getattr(paddle.vision.datasets, dataset_name)
transform = build_transforms(transforms)
self.return_label = return_label
param_dict = {}
param_names = list(dataset_cls.__init__.__code__.co_varnames)
if 'transform' in param_names:
param_dict['transform'] = transform
for name in param_names:
if name in cfg:
param_dict[name] = cfg.get(name)
if params is not None:
for name in param_names:
if name in params:
param_dict[name] = params[name]
self.dataset = dataset_cls(**param_dict)
......@@ -53,7 +62,7 @@ class CommonVisionDataset(BaseDataset):
if isinstance(return_list, (tuple, list)):
if len(return_list) == 2:
return_dict['img'] = return_list[0]
if self.return_cls:
if self.return_label:
return_dict['class_id'] = np.asarray(return_list[1])
else:
return_dict['img'] = return_list[0]
......
......@@ -211,12 +211,24 @@ class Trainer:
current_paths = self.model.get_image_paths()
current_visuals = self.model.get_current_visuals()
for j in range(len(current_paths)):
short_path = os.path.basename(current_paths[j])
basename = os.path.splitext(short_path)[0]
if len(current_visuals) > 0 and list(
current_visuals.values())[0].shape == 4:
num_samples = list(current_visuals.values())[0].shape[0]
else:
num_samples = 1
for j in range(num_samples):
if j < len(current_paths):
short_path = os.path.basename(current_paths[j])
basename = os.path.splitext(short_path)[0]
else:
basename = '{:04d}_{:04d}'.format(i, j)
for k, img_tensor in current_visuals.items():
name = '%s_%s' % (basename, k)
visual_results.update({name: img_tensor[j]})
if len(img_tensor.shape) == 4:
visual_results.update({name: img_tensor[j]})
else:
visual_results.update({name: img_tensor})
self.visual('visual_test',
visual_results=visual_results,
......
......@@ -50,7 +50,7 @@ class BaseModel(ABC):
# save checkpoint (model.nets) \/
"""
def __init__(self):
def __init__(self, params=None):
"""Initialize the BaseModel class.
When creating your custom class, you need to implement your own initialization.
......@@ -62,7 +62,13 @@ class BaseModel(ABC):
-- self.optimizers (dict): define and initialize optimizers. You can define one optimizer for each network.
If two networks are updated at the same time, you can use itertools.chain to group them.
See cycle_gan_model.py for an example.
Args:
params (dict): Hyper params for train or test. Default: None.
"""
self.params = params
self.is_train = True if self.params is None else self.params.get(
'is_train', True)
self.nets = OrderedDict()
self.optimizers = OrderedDict()
......@@ -149,7 +155,9 @@ class BaseModel(ABC):
def get_image_paths(self):
""" Return image paths that are used to load current data"""
return self.image_paths
if hasattr(self, 'image_paths'):
return self.image_paths
return []
def get_current_visuals(self):
"""Return visualization images."""
......
......@@ -19,7 +19,7 @@ from .base_model import BaseModel
from .builder import MODELS
from .generators.builder import build_generator
from .discriminators.builder import build_discriminator
from .criterions.gan_loss import GANLoss
from .criterions.builder import build_criterion
from ..solver import build_optimizer
from ..modules.init import init_weights
......@@ -32,44 +32,46 @@ class GANModel(BaseModel):
vanilla GAN paper: https://arxiv.org/abs/1406.2661
"""
def __init__(self, cfg):
def __init__(self,
generator,
discriminator=None,
gan_criterion=None,
params=None):
"""Initialize the GAN Model class.
Parameters:
cfg (config dict)-- stores all the experiment flags; needs to be a subclass of Dict
Args:
generator (dict): config of generator.
discriminator (dict): config of discriminator.
gan_criterion (dict): config of gan criterion.
params (dict): hyper params for train or test. Default: None.
"""
super(GANModel, self).__init__(cfg)
self.step = 0
self.n_critic = cfg.model.get('n_critic', 1)
self.visual_interval = cfg.log_config.visiual_interval
self.samples_every_row = cfg.model.get('samples_every_row', 8)
# define networks (both generator and discriminator)
self.nets['netG'] = build_generator(cfg.model.generator)
super(GANModel, self).__init__(params)
self.iter = 0
self.disc_iters = 1 if self.params is None else self.params.get(
'disc_iters', 1)
self.disc_start_iters = (0 if self.params is None else self.params.get(
'disc_start_iters', 0))
self.samples_every_row = (8 if self.params is None else self.params.get(
'samples_every_row', 8))
self.visual_interval = (500 if self.params is None else self.params.get(
'visual_interval', 500))
# define generator
self.nets['netG'] = build_generator(generator)
init_weights(self.nets['netG'])
# define a discriminator
if self.is_train:
self.nets['netD'] = build_discriminator(cfg.model.discriminator)
init_weights(self.nets['netD'])
if discriminator is not None:
self.nets['netD'] = build_discriminator(discriminator)
init_weights(self.nets['netD'])
if self.is_train:
self.losses = {}
# define loss functions
self.criterionGAN = GANLoss(cfg.model.gan_mode)
# build optimizers
self.build_lr_scheduler()
self.optimizers['optimizer_G'] = build_optimizer(
cfg.optimizer,
self.lr_scheduler,
parameter_list=self.nets['netG'].parameters())
self.optimizers['optimizer_D'] = build_optimizer(
cfg.optimizer,
self.lr_scheduler,
parameter_list=self.nets['netD'].parameters())
def set_input(self, input):
if gan_criterion:
self.criterionGAN = build_criterion(gan_criterion)
def setup_input(self, input):
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
Parameters:
......@@ -131,7 +133,7 @@ class GANModel(BaseModel):
self.loss_D_real = self.criterionGAN(pred_real, True, True)
# combine loss and calculate gradients
if self.cfg.model.gan_mode in ['vanilla', 'lsgan']:
if self.criterionGAN.gan_mode in ['vanilla', 'lsgan']:
self.loss_D = self.loss_D + (self.loss_D_fake +
self.loss_D_real) * 0.5
else:
......@@ -159,34 +161,34 @@ class GANModel(BaseModel):
self.losses['G_adv_loss'] = self.loss_G_GAN
def optimize_parameters(self):
def train_iter(self, optimizers=None):
# compute fake images: G(imgs)
self.forward()
# update D
self.set_requires_grad(self.nets['netD'], True)
self.optimizers['optimizer_D'].clear_grad()
optimizers['optimizer_D'].clear_grad()
self.backward_D()
self.optimizers['optimizer_D'].step()
optimizers['optimizer_D'].step()
self.set_requires_grad(self.nets['netD'], False)
# weight clip
if self.cfg.model.gan_mode == 'wgan':
if self.criterionGAN.gan_mode == 'wgan':
with paddle.no_grad():
for p in self.nets['netD'].parameters():
p[:] = p.clip(-0.01, 0.01)
if self.step % self.n_critic == 0:
if self.iter > self.disc_start_iters and self.iter % self.disc_iters == 0:
# update G
self.optimizers['optimizer_G'].clear_grad()
optimizers['optimizer_G'].clear_grad()
self.backward_G()
self.optimizers['optimizer_G'].step()
optimizers['optimizer_G'].step()
if self.step % self.visual_interval == 0:
if self.iter % self.visual_interval == 0:
with paddle.no_grad():
self.visual_items['fixed_generated_imgs'] = make_grid(
self.nets['netG'](*self.G_fixed_inputs),
self.samples_every_row)
self.step += 1
self.iter += 1
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册