diff --git a/configs/wgan_mnist.yaml b/configs/wgan_mnist.yaml index 4db41860e908bd04fbb20b6024e708a756a51287..930628b9f5cba98600158f3f44e7223ddb424100 100644 --- a/configs/wgan_mnist.yaml +++ b/configs/wgan_mnist.yaml @@ -15,18 +15,20 @@ model: n_layers: 3 input_nc: 1 norm_type: instance - gan_mode: wgan - n_critic: 5 + gan_criterion: + name: GANLoss + gan_mode: wgan + params: + disc_iters: 5 + visual_interval: 500 dataset: train: name: CommonVisionDataset - class_name: MNIST - dataroot: None + dataset_name: MNIST num_workers: 4 batch_size: 64 - mode: train - return_cls: False + return_label: False transforms: - name: Normalize mean: [127.5] @@ -34,28 +36,36 @@ dataset: keys: [image] test: name: CommonVisionDataset - class_name: MNIST - dataroot: None + dataset_name: MNIST num_workers: 0 batch_size: 64 - mode: test + return_label: False transforms: - name: Normalize mean: [127.5] std: [127.5] keys: [image] - return_cls: False -optimizer: - name: Adam - beta1: 0.5 - lr_scheduler: - name: linear + name: LinearDecay learning_rate: 0.0002 start_epoch: 100 decay_epochs: 100 + # will get from real dataset + iters_per_epoch: 1 + +optimizer: + optimizer_G: + name: Adam + net_names: + - netG + beta1: 0.5 + optimizer_D: + name: Adam + net_names: + - netD + beta1: 0.5 log_config: interval: 100 diff --git a/ppgan/datasets/common_vision_dataset.py b/ppgan/datasets/common_vision_dataset.py index f6bab62aeb0b20221132b9bb2c16aa4c7510f46a..2e69104603defab1c03705b02043bcb535f18079 100644 --- a/ppgan/datasets/common_vision_dataset.py +++ b/ppgan/datasets/common_vision_dataset.py @@ -21,29 +21,38 @@ from .transforms.builder import build_transforms @DATASETS.register() -class CommonVisionDataset(BaseDataset): +class CommonVisionDataset(paddle.io.Dataset): """ - Dataset for using paddle vision default datasets + Dataset for using paddle vision default datasets, such as mnist, flowers. 
""" - def __init__(self, cfg): + def __init__(self, + dataset_name, + transforms=None, + return_label=True, + params=None): """Initialize this dataset class. Args: - cfg (dict) -- stores all the experiment flags + dataset_name (str): return a dataset from paddle.vision.datasets by this option. + transforms (list[dict]): A sequence of data transforms config. + return_label (bool): whether to retuan a label of a sample. + params (dict): paramters of paddle.vision.datasets. """ - super(CommonVisionDataset, self).__init__(cfg) + super(CommonVisionDataset, self).__init__() - dataset_cls = getattr(paddle.vision.datasets, cfg.pop('class_name')) - transform = build_transforms(cfg.pop('transforms', None)) - self.return_cls = cfg.pop('return_cls', True) + dataset_cls = getattr(paddle.vision.datasets, dataset_name) + transform = build_transforms(transforms) + self.return_label = return_label param_dict = {} param_names = list(dataset_cls.__init__.__code__.co_varnames) if 'transform' in param_names: param_dict['transform'] = transform - for name in param_names: - if name in cfg: - param_dict[name] = cfg.get(name) + + if params is not None: + for name in param_names: + if name in params: + param_dict[name] = params[name] self.dataset = dataset_cls(**param_dict) @@ -53,7 +62,7 @@ class CommonVisionDataset(BaseDataset): if isinstance(return_list, (tuple, list)): if len(return_list) == 2: return_dict['img'] = return_list[0] - if self.return_cls: + if self.return_label: return_dict['class_id'] = np.asarray(return_list[1]) else: return_dict['img'] = return_list[0] diff --git a/ppgan/engine/trainer.py b/ppgan/engine/trainer.py index 4812a125ee2a4ec4e32d051d2d9be3fbaad8e928..6d6b22835471883ca654962b627deb6c7ca75169 100644 --- a/ppgan/engine/trainer.py +++ b/ppgan/engine/trainer.py @@ -211,12 +211,24 @@ class Trainer: current_paths = self.model.get_image_paths() current_visuals = self.model.get_current_visuals() - for j in range(len(current_paths)): - short_path = 
os.path.basename(current_paths[j]) -                basename = os.path.splitext(short_path)[0] +            if len(current_visuals) > 0 and len( +                    list(current_visuals.values())[0].shape) == 4: +                num_samples = list(current_visuals.values())[0].shape[0] +            else: +                num_samples = 1 + +            for j in range(num_samples): +                if j < len(current_paths): +                    short_path = os.path.basename(current_paths[j]) +                    basename = os.path.splitext(short_path)[0] +                else: +                    basename = '{:04d}_{:04d}'.format(i, j) for k, img_tensor in current_visuals.items(): name = '%s_%s' % (basename, k) -                    visual_results.update({name: img_tensor[j]}) +                    if len(img_tensor.shape) == 4: +                        visual_results.update({name: img_tensor[j]}) +                    else: +                        visual_results.update({name: img_tensor}) self.visual('visual_test', visual_results=visual_results, diff --git a/ppgan/models/base_model.py b/ppgan/models/base_model.py index 6b92db7f3d0d0aab4078762e212f554713620a52..e70524715328491cef9430823fc97b6514add7fa 100644 --- a/ppgan/models/base_model.py +++ b/ppgan/models/base_model.py @@ -50,7 +50,7 @@ class BaseModel(ABC): # save checkpoint (model.nets) \/ """ -    def __init__(self): +    def __init__(self, params=None): """Initialize the BaseModel class. When creating your custom class, you need to implement your own initialization. @@ -62,7 +62,13 @@ class BaseModel(ABC): -- self.optimizers (dict): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example. + +        Args: +            params (dict): Hyper params for train or test. Default: None. 
""" + self.params = params + self.is_train = True if self.params is None else self.params.get( + 'is_train', True) self.nets = OrderedDict() self.optimizers = OrderedDict() @@ -149,7 +155,9 @@ class BaseModel(ABC): def get_image_paths(self): """ Return image paths that are used to load current data""" - return self.image_paths + if hasattr(self, 'image_paths'): + return self.image_paths + return [] def get_current_visuals(self): """Return visualization images.""" diff --git a/ppgan/models/gan_model.py b/ppgan/models/gan_model.py index e788d6fc40c05e4bbeb4fc0fedc53e762da8ca92..6c82488bd0fe495b3ee0fa13ea8c28795b559261 100644 --- a/ppgan/models/gan_model.py +++ b/ppgan/models/gan_model.py @@ -19,7 +19,7 @@ from .base_model import BaseModel from .builder import MODELS from .generators.builder import build_generator from .discriminators.builder import build_discriminator -from .criterions.gan_loss import GANLoss +from .criterions.builder import build_criterion from ..solver import build_optimizer from ..modules.init import init_weights @@ -32,44 +32,46 @@ class GANModel(BaseModel): vanilla GAN paper: https://arxiv.org/abs/1406.2661 """ - def __init__(self, cfg): + def __init__(self, + generator, + discriminator=None, + gan_criterion=None, + params=None): """Initialize the GAN Model class. - Parameters: - cfg (config dict)-- stores all the experiment flags; needs to be a subclass of Dict + Args: + generator (dict): config of generator. + discriminator (dict): config of discriminator. + gan_criterion (dict): config of gan criterion. + params (dict): hyper params for train or test. Default: None. 
""" - super(GANModel, self).__init__(cfg) - self.step = 0 - self.n_critic = cfg.model.get('n_critic', 1) - self.visual_interval = cfg.log_config.visiual_interval - self.samples_every_row = cfg.model.get('samples_every_row', 8) - - # define networks (both generator and discriminator) - self.nets['netG'] = build_generator(cfg.model.generator) + super(GANModel, self).__init__(params) + self.iter = 0 + + self.disc_iters = 1 if self.params is None else self.params.get( + 'disc_iters', 1) + self.disc_start_iters = (0 if self.params is None else self.params.get( + 'disc_start_iters', 0)) + self.samples_every_row = (8 if self.params is None else self.params.get( + 'samples_every_row', 8)) + self.visual_interval = (500 if self.params is None else self.params.get( + 'visual_interval', 500)) + + # define generator + self.nets['netG'] = build_generator(generator) init_weights(self.nets['netG']) # define a discriminator if self.is_train: - self.nets['netD'] = build_discriminator(cfg.model.discriminator) - init_weights(self.nets['netD']) + if discriminator is not None: + self.nets['netD'] = build_discriminator(discriminator) + init_weights(self.nets['netD']) - if self.is_train: - self.losses = {} # define loss functions - self.criterionGAN = GANLoss(cfg.model.gan_mode) - - # build optimizers - self.build_lr_scheduler() - self.optimizers['optimizer_G'] = build_optimizer( - cfg.optimizer, - self.lr_scheduler, - parameter_list=self.nets['netG'].parameters()) - self.optimizers['optimizer_D'] = build_optimizer( - cfg.optimizer, - self.lr_scheduler, - parameter_list=self.nets['netD'].parameters()) - - def set_input(self, input): + if gan_criterion: + self.criterionGAN = build_criterion(gan_criterion) + + def setup_input(self, input): """Unpack input data from the dataloader and perform necessary pre-processing steps. 
Parameters: @@ -131,7 +133,7 @@ class GANModel(BaseModel): self.loss_D_real = self.criterionGAN(pred_real, True, True) # combine loss and calculate gradients - if self.cfg.model.gan_mode in ['vanilla', 'lsgan']: + if self.criterionGAN.gan_mode in ['vanilla', 'lsgan']: self.loss_D = self.loss_D + (self.loss_D_fake + self.loss_D_real) * 0.5 else: @@ -159,34 +161,34 @@ class GANModel(BaseModel): self.losses['G_adv_loss'] = self.loss_G_GAN - def optimize_parameters(self): + def train_iter(self, optimizers=None): # compute fake images: G(imgs) self.forward() # update D self.set_requires_grad(self.nets['netD'], True) - self.optimizers['optimizer_D'].clear_grad() + optimizers['optimizer_D'].clear_grad() self.backward_D() - self.optimizers['optimizer_D'].step() + optimizers['optimizer_D'].step() self.set_requires_grad(self.nets['netD'], False) # weight clip - if self.cfg.model.gan_mode == 'wgan': + if self.criterionGAN.gan_mode == 'wgan': with paddle.no_grad(): for p in self.nets['netD'].parameters(): p[:] = p.clip(-0.01, 0.01) - if self.step % self.n_critic == 0: + if self.iter > self.disc_start_iters and self.iter % self.disc_iters == 0: # update G - self.optimizers['optimizer_G'].clear_grad() + optimizers['optimizer_G'].clear_grad() self.backward_G() - self.optimizers['optimizer_G'].step() + optimizers['optimizer_G'].step() - if self.step % self.visual_interval == 0: + if self.iter % self.visual_interval == 0: with paddle.no_grad(): self.visual_items['fixed_generated_imgs'] = make_grid( self.nets['netG'](*self.G_fixed_inputs), self.samples_every_row) - self.step += 1 + self.iter += 1