# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import random
import paddle
import paddle.nn as nn

from .base_model import BaseModel

from .builder import MODELS
from .criterions import build_criterion
from .generators.builder import build_generator
from .discriminators.builder import build_discriminator
from ..solver import build_lr_scheduler, build_optimizer


def r1_penalty(real_pred, real_img):
    """
    R1 regularization for discriminator. The core idea is to
    penalize the gradient on real data alone: when the
    generator distribution produces the true data distribution
    and the discriminator is equal to 0 on the data manifold, the
    gradient penalty ensures that the discriminator cannot create
    a non-zero gradient orthogonal to the data manifold without
    suffering a loss in the GAN game.

    Ref:
    Eq. 9 in Which training methods for GANs do actually converge.

    Args:
        real_pred (Tensor): discriminator output computed from ``real_img``.
        real_img (Tensor): input the gradient is taken with respect to; it
            must take part in the autograd graph that produced
            ``real_pred``.

    Returns:
        Tensor: batch mean of the per-sample squared gradient norm.
    """
    # create_graph=True keeps the gradient differentiable so that the
    # penalty itself can be back-propagated into the discriminator.
    grad_real = paddle.grad(outputs=real_pred.sum(),
                            inputs=real_img,
                            create_graph=True)[0]
    # Flatten everything but the leading (batch) axis, then sum the squares.
    grad_penalty = (grad_real * grad_real).reshape([grad_real.shape[0],
                                                    -1]).sum(1).mean()
    return grad_penalty


def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
    """Compute the path length regularization term for the generator.

    Args:
        fake_img (Tensor): generator output; axes 2 and 3 are used to scale
            the random projection (presumably height and width - confirm
            against the generator).
        latents (Tensor): latent codes that produced ``fake_img``; the
            gradient w.r.t. them is reduced with ``sum(2).mean(1)``, so a
            3-D tensor is expected.
        mean_path_length (float | Tensor): running mean of the path lengths
            from previous calls.
        decay (float): update rate of the running mean.

    Returns:
        tuple: ``(path_penalty, mean_path_length_of_this_batch,
        updated_running_mean)``; the last two are detached from the graph.
    """
    # Random projection of the image, scaled by 1 / sqrt(shape[2] * shape[3]).
    noise = paddle.randn(fake_img.shape) / math.sqrt(
        fake_img.shape[2] * fake_img.shape[3])
    grad = paddle.grad(outputs=(fake_img * noise).sum(),
                       inputs=latents,
                       create_graph=True)[0]
    path_lengths = paddle.sqrt((grad * grad).sum(2).mean(1))

    # Moving average of the observed path length.
    path_mean = mean_path_length + decay * (path_lengths.mean() -
                                            mean_path_length)

    # Penalize the squared deviation from the running mean.
    deviation = path_lengths - path_mean
    path_penalty = (deviation * deviation).mean()

    return path_penalty, path_lengths.detach().mean(), path_mean.detach()


@MODELS.register()
class StyleGAN2Model(BaseModel):
    """
    This class implements the StyleGAN2 model: a style-based generator
    trained adversarially against a discriminator, with R1 regularization
    on the discriminator, path length regularization on the generator and
    an exponential moving average (EMA) copy of the generator.

    StyleGAN2 paper: https://arxiv.org/pdf/1912.04958.pdf
    """
    def __init__(self,
                 generator,
                 discriminator=None,
                 gan_criterion=None,
                 num_style_feat=512,
                 mixing_prob=0.9,
                 r1_reg_weight=10.,
                 path_reg_weight=2.,
                 path_batch_shrink=2.,
                 params=None):
        """Initialize the StyleGAN2Model class.

        Args:
            generator (dict): config of generator.
            discriminator (dict): config of discriminator.
            gan_criterion (dict): config of gan criterion.
            num_style_feat (int): size of each sampled latent vector.
            mixing_prob (float): probability of sampling two latent vectors
                (style mixing) instead of one.
            r1_reg_weight (float): weight of the R1 penalty.
            path_reg_weight (float): weight of the path length penalty.
            path_batch_shrink (float): the batch size is divided by this
                value for the path length regularization pass.
            params (dict): optional overrides for ``gen_iters``,
                ``disc_iters``, ``disc_start_iters`` and ``visual_iters``.
        """
        super(StyleGAN2Model, self).__init__(params)
        # Interval (in iterations) of the generator path regularization.
        self.gen_iters = 4 if self.params is None else self.params.get(
            'gen_iters', 4)
        # Interval (in iterations) of the discriminator R1 regularization.
        self.disc_iters = 16 if self.params is None else self.params.get(
            'disc_iters', 16)
        self.disc_start_iters = (0 if self.params is None else self.params.get(
            'disc_start_iters', 0))
        # Interval (in iterations) of the EMA generator preview.
        self.visual_iters = (500 if self.params is None else self.params.get(
            'visual_iters', 500))
        self.mixing_prob = mixing_prob
        self.num_style_feat = num_style_feat
        self.r1_reg_weight = r1_reg_weight

        self.path_reg_weight = path_reg_weight
        self.path_batch_shrink = path_batch_shrink
        self.mean_path_length = 0

        self.nets['gen'] = build_generator(generator)

        # define discriminators
        if discriminator:
            self.nets['disc'] = build_discriminator(discriminator)

        # The EMA generator starts as an exact copy of the generator:
        # decay=0 overwrites its parameters with the generator's.
        self.nets['gen_ema'] = build_generator(generator)
        self.model_ema(0)

        self.nets['gen'].train()
        self.nets['gen_ema'].eval()
        # The discriminator is optional (e.g. when no config is given), so
        # only switch it to train mode when it has actually been built.
        if 'disc' in self.nets:
            self.nets['disc'].train()
        self.current_iter = 1

        # define loss functions
        if gan_criterion:
            self.gan_criterion = build_criterion(gan_criterion)

    def setup_lr_schedulers(self, cfg):
        """Build the learning rate schedulers of both networks.

        The base learning rate is scaled by ``k / (k + 1)`` where ``k`` is
        the regularization interval of the respective network (presumably
        to compensate for the regularization being applied only every
        ``k`` iterations - see the StyleGAN2 paper).

        Args:
            cfg (dict): scheduler config containing ``learning_rate``.

        Returns:
            dict: schedulers stored under the keys ``'gen'`` and ``'disc'``.
        """
        self.lr_scheduler = dict()

        gen_cfg = cfg.copy()
        net_g_reg_ratio = self.gen_iters / (self.gen_iters + 1)
        gen_cfg['learning_rate'] = cfg['learning_rate'] * net_g_reg_ratio
        self.lr_scheduler['gen'] = build_lr_scheduler(gen_cfg)

        disc_cfg = cfg.copy()
        net_d_reg_ratio = self.disc_iters / (self.disc_iters + 1)
        disc_cfg['learning_rate'] = cfg['learning_rate'] * net_d_reg_ratio
        self.lr_scheduler['disc'] = build_lr_scheduler(disc_cfg)

        return self.lr_scheduler

    def setup_optimizers(self, lr, cfg):
        """Build the optimizers of the generator and the discriminator.

        Args:
            lr (dict): schedulers keyed by ``'gen'`` and ``'disc'``.
            cfg (dict): maps ``'optimG'`` / ``'optimD'`` to an optimizer
                config; each config lists the optimized networks under
                ``net_names``.

        Returns:
            dict: the optimizers, keyed by the names used in ``cfg``.

        Raises:
            ValueError: if ``cfg`` has a key other than ``'optimG'`` or
                ``'optimD'``.
        """
        for opt_name, opt_cfg in cfg.items():
            if opt_name == 'optimG':
                _lr = lr['gen']
            elif opt_name == 'optimD':
                _lr = lr['disc']
            else:
                raise ValueError("opt name must be in ['optimG', optimD]")

            # Work on a copy so the caller's config is not mutated by pop().
            cfg_ = opt_cfg.copy()
            net_names = cfg_.pop('net_names')
            parameters = []
            for net_name in net_names:
                parameters += self.nets[net_name].parameters()
            self.optimizers[opt_name] = build_optimizer(cfg_, _lr, parameters)

        return self.optimizers

    def get_bare_model(self, net):
        """Get bare model, especially under wrapping with DataParallel.

        Args:
            net: a network, possibly wrapped in ``paddle.DataParallel``.

        Returns:
            The wrapped layer if ``net`` is a ``paddle.DataParallel``
            instance, otherwise ``net`` itself.
        """
        if isinstance(net, paddle.DataParallel):
            net = net._layers
        return net

    def model_ema(self, decay=0.999):
        """Update the EMA generator from the current generator.

        Every parameter of ``gen_ema`` becomes
        ``ema * decay + current * (1 - decay)``.

        Args:
            decay (float): weight of the previous EMA value; ``0`` copies
                the generator parameters verbatim.
        """
        net_g = self.get_bare_model(self.nets['gen'])
        net_g_params = dict(net_g.named_parameters())

        net_g_ema = self.get_bare_model(self.nets['gen_ema'])
        net_g_ema_params = dict(net_g_ema.named_parameters())

        # The update is plain bookkeeping: no gradient must flow through it,
        # so skip autograd recording for the arithmetic below.
        with paddle.no_grad():
            for name, ema_param in net_g_ema_params.items():
                ema_param.set_value(ema_param * decay +
                                    net_g_params[name] * (1 - decay))

    def setup_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        Args:
            input (dict): include the data itself and its metadata information.
                Only the entry ``'A'`` (the real images) is read.
        """
        self.real_img = paddle.to_tensor(input['A'])

    def forward(self):
        """Run forward pass.

        Intentionally a no-op: ``train_iter`` invokes the networks directly.
        """
        pass

    def make_noise(self, batch, num_noise):
        """Sample standard normal latent vectors.

        Args:
            batch (int): number of samples per latent tensor.
            num_noise (int): number of latent tensors to sample.

        Returns:
            A single ``[batch, num_style_feat]`` tensor when
            ``num_noise == 1``, otherwise a list of ``num_noise`` such
            tensors.
        """
        if num_noise == 1:
            return paddle.randn([batch, self.num_style_feat])
        return [
            paddle.randn([batch, self.num_style_feat])
            for _ in range(num_noise)
        ]

    def mixing_noise(self, batch, prob):
        """Sample the latent input of the generator, with style mixing.

        Args:
            batch (int): number of samples per latent tensor.
            prob (float): probability of returning two latent tensors.

        Returns:
            list: two latent tensors with probability ``prob``, otherwise a
            list holding a single latent tensor.
        """
        if random.random() < prob:
            return self.make_noise(batch, 2)
        return [self.make_noise(batch, 1)]

    def train_iter(self, optimizers=None):
        """Run one training iteration: discriminator step, generator step,
        EMA update and periodic EMA preview.

        Args:
            optimizers (dict): optimizers keyed by ``'optimD'`` and
                ``'optimG'``.
        """
        current_iter = self.current_iter

        # ------------------------ optimize discriminator ------------------
        self.set_requires_grad(self.nets['disc'], True)
        optimizers['optimD'].clear_grad()

        batch = self.real_img.shape[0]
        noise = self.mixing_noise(batch, self.mixing_prob)
        fake_img, _ = self.nets['gen'](noise)
        self.visual_items['real_img'] = self.real_img
        self.visual_items['fake_img'] = fake_img
        # detach(): the discriminator step must not update the generator.
        fake_pred = self.nets['disc'](fake_img.detach())

        real_pred = self.nets['disc'](self.real_img)
        # wgan loss with softplus (logistic loss) for discriminator
        l_d_total = 0.
        l_d = self.gan_criterion(real_pred, True,
                                 is_disc=True) + self.gan_criterion(
                                     fake_pred, False, is_disc=True)
        self.losses['l_d'] = l_d
        # In wgan, real_score should be positive and fake_score should be
        # negative
        self.losses['real_score'] = real_pred.detach().mean()
        self.losses['fake_score'] = fake_pred.detach().mean()

        l_d_total += l_d
        if current_iter % self.disc_iters == 0:
            # R1 needs the gradient w.r.t. the real images, so they have to
            # take part in the autograd graph of a fresh forward pass.
            self.real_img.stop_gradient = False
            real_pred = self.nets['disc'](self.real_img)
            l_d_r1 = r1_penalty(real_pred, self.real_img)
            # Scaled by disc_iters because the penalty is only applied every
            # disc_iters iterations; the zero-weighted real_pred term keeps
            # the discriminator output in the graph.
            l_d_r1 = (self.r1_reg_weight / 2 * l_d_r1 * self.disc_iters +
                      0 * real_pred[0])

            self.losses['l_d_r1'] = l_d_r1.detach().mean()

            l_d_total += l_d_r1
        l_d_total.backward()

        optimizers['optimD'].step()

        # -------------------------- optimize generator --------------------
        self.set_requires_grad(self.nets['disc'], False)
        optimizers['optimG'].clear_grad()

        noise = self.mixing_noise(batch, self.mixing_prob)
        fake_img, _ = self.nets['gen'](noise)
        fake_pred = self.nets['disc'](fake_img)

        # wgan loss with softplus (non-saturating loss) for generator
        l_g_total = 0.
        l_g = self.gan_criterion(fake_pred, True, is_disc=False)
        self.losses['l_g'] = l_g

        l_g_total += l_g
        if current_iter % self.gen_iters == 0:
            # path_batch_shrink may be a float (its default is 2.); floor
            # division then yields a float, which is not a valid tensor
            # dimension, hence the explicit int().
            path_batch_size = max(1, int(batch // self.path_batch_shrink))
            noise = self.mixing_noise(path_batch_size, self.mixing_prob)
            fake_img, latents = self.nets['gen'](noise, return_latents=True)
            l_g_path, path_lengths, self.mean_path_length = g_path_regularize(
                fake_img, latents, self.mean_path_length)

            # Scaled by gen_iters because the penalty is only applied every
            # gen_iters iterations; the zero-weighted fake_img term keeps the
            # generator output in the graph.
            l_g_path = (self.path_reg_weight * self.gen_iters * l_g_path +
                        0 * fake_img[0, 0, 0, 0])

            l_g_total += l_g_path
            self.losses['l_g_path'] = l_g_path.detach().mean()
            self.losses['path_length'] = path_lengths
        l_g_total.backward()
        optimizers['optimG'].step()

        # EMA
        # NOTE(review): presumably a half-life of 10k images at a batch size
        # of 32, as in the reference StyleGAN2 implementation - confirm.
        self.model_ema(decay=0.5**(32 / (10 * 1000)))

        # Refresh the EMA preview once every visual_iters iterations.
        if self.current_iter % self.visual_iters == 0:
            sample_z = [self.make_noise(1, 1)]
            sample, _ = self.nets['gen_ema'](sample_z)
            self.visual_items['fake_img_ema'] = sample

        self.current_iter += 1