styleganv2_model.py 10.3 KB
Newer Older
L
LielinJiang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import random
import paddle
import paddle.nn as nn
from .base_model import BaseModel

from .builder import MODELS
from .criterions import build_criterion
from .generators.builder import build_generator
from .discriminators.builder import build_discriminator
from ..solver import build_lr_scheduler, build_optimizer


def r1_penalty(real_pred, real_img):
    """
    Compute the R1 gradient penalty of the discriminator on real images.

    The summed discriminator output is differentiated with respect to the
    real images, and the squared gradient is summed per sample and then
    averaged over the batch. Penalizing this quantity discourages the
    discriminator from keeping a non-zero gradient on the real data.

    Ref:
    Eq. 9 in Which training methods for GANs do actually converge.

    Args:
        real_pred: discriminator output for ``real_img``.
        real_img: real images; gradients must be enabled on this tensor
            before the discriminator forward pass, otherwise ``paddle.grad``
            cannot differentiate with respect to it.

    Returns:
        The batch mean of the per-sample squared gradient norm.
    """
    # create_graph=True keeps the gradient differentiable so the penalty
    # itself can be back-propagated into the discriminator weights.
    gradients = paddle.grad(outputs=real_pred.sum(),
                            inputs=real_img,
                            create_graph=True)
    grad_real = gradients[0]

    batch_size = grad_real.shape[0]
    squared = grad_real * grad_real
    per_sample = squared.reshape([batch_size, -1]).sum(1)
    return per_sample.mean()


def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
    """
    Compute the path length penalty for the generator.

    The generated images are projected onto random noise (scaled by
    1 / sqrt(H * W), with H and W taken from axes 2 and 3 of ``fake_img``),
    and the projection is differentiated with respect to ``latents``. The
    per-sample path length is the square root of the squared gradient summed
    over axis 2 and averaged over axis 1.

    Args:
        fake_img: generated images; axes 2 and 3 are read as spatial sizes.
        latents: latent codes the images were generated from (indexed here
            as a 3-D tensor).
        mean_path_length: running mean of the path length from previous
            calls (a number or a tensor).
        decay: interpolation factor used to move the running mean towards
            the current batch mean.

    Returns:
        A tuple ``(path_penalty, mean_of_path_lengths, updated_running_mean)``
        where the last two entries are detached from the graph.
    """
    height, width = fake_img.shape[2], fake_img.shape[3]
    noise = paddle.randn(fake_img.shape) / math.sqrt(height * width)

    # create_graph=True so the penalty can be back-propagated through the
    # gradient computation itself.
    projection = (fake_img * noise).sum()
    latent_grad = paddle.grad(outputs=projection,
                              inputs=latents,
                              create_graph=True)[0]
    path_lengths = paddle.sqrt((latent_grad * latent_grad).sum(2).mean(1))

    # Moving average of the path length.
    batch_mean = path_lengths.mean()
    path_mean = mean_path_length + decay * (batch_mean - mean_path_length)

    deviation = path_lengths - path_mean
    path_penalty = (deviation * deviation).mean()

    return path_penalty, path_lengths.detach().mean(), path_mean.detach()


@MODELS.register()
class StyleGAN2Model(BaseModel):
    """
    This class implements the StyleGANV2 model, for learning image-to-image translation without paired data.

    StyleGAN2 paper: https://arxiv.org/pdf/1912.04958.pdf
    """
    def __init__(self,
                 generator,
                 discriminator=None,
                 gan_criterion=None,
                 num_style_feat=512,
                 mixing_prob=0.9,
                 r1_reg_weight=10.,
                 path_reg_weight=2.,
                 path_batch_shrink=2.,
                 params=None):
        """Initialize the CycleGAN class.

        Args:
            generator (dict): config of generator.
            discriminator (dict): config of discriminator.
            gan_criterion (dict): config of gan criterion.
        """
        super(StyleGAN2Model, self).__init__(params)
        self.gen_iters = 4 if self.params is None else self.params.get(
            'gen_iters', 4)
        self.disc_iters = 16 if self.params is None else self.params.get(
            'disc_iters', 16)
        self.disc_start_iters = (0 if self.params is None else self.params.get(
            'disc_start_iters', 0))

        self.visual_iters = (500 if self.params is None else self.params.get(
            'visual_iters', 500))

        self.mixing_prob = mixing_prob
        self.num_style_feat = num_style_feat
        self.r1_reg_weight = r1_reg_weight

        self.path_reg_weight = path_reg_weight
        self.path_batch_shrink = path_batch_shrink
        self.mean_path_length = 0

        self.nets['gen'] = build_generator(generator)

        # define discriminators
        if discriminator:
            self.nets['disc'] = build_discriminator(discriminator)

            self.nets['gen_ema'] = build_generator(generator)
            self.model_ema(0)

            self.nets['gen'].train()
            self.nets['gen_ema'].eval()
            self.nets['disc'].train()
            self.current_iter = 1

        # define loss functions
        if gan_criterion:
            self.gan_criterion = build_criterion(gan_criterion)

    def setup_lr_schedulers(self, cfg):
        self.lr_scheduler = dict()
        gen_cfg = cfg.copy()
        net_g_reg_ratio = self.gen_iters / (self.gen_iters + 1)
        gen_cfg['learning_rate'] = cfg['learning_rate'] * net_g_reg_ratio
        self.lr_scheduler['gen'] = build_lr_scheduler(gen_cfg)

        disc_cfg = cfg.copy()
        net_d_reg_ratio = self.disc_iters / (self.disc_iters + 1)
        disc_cfg['learning_rate'] = cfg['learning_rate'] * net_d_reg_ratio
        self.lr_scheduler['disc'] = build_lr_scheduler(disc_cfg)
        return self.lr_scheduler

    def setup_optimizers(self, lr, cfg):
        for opt_name, opt_cfg in cfg.items():
            if opt_name == 'optimG':
                _lr = lr['gen']
            elif opt_name == 'optimD':
                _lr = lr['disc']
            else:
                raise ValueError("opt name must be in ['optimG', optimD]")

            cfg_ = opt_cfg.copy()
            net_names = cfg_.pop('net_names')
            parameters = []
            for net_name in net_names:
                parameters += self.nets[net_name].parameters()
            self.optimizers[opt_name] = build_optimizer(cfg_, _lr, parameters)

        return self.optimizers

    def get_bare_model(self, net):
        """Get bare model, especially under wrapping with DataParallel.
        """
        if isinstance(net, (paddle.DataParallel)):
            net = net._layers
        return net

    def model_ema(self, decay=0.999):
        net_g = self.get_bare_model(self.nets['gen'])
        net_g_params = dict(net_g.named_parameters())

        neg_g_ema = self.get_bare_model(self.nets['gen_ema'])
        net_g_ema_params = dict(neg_g_ema.named_parameters())

        for k in net_g_ema_params.keys():
            net_g_ema_params[k].set_value(net_g_ema_params[k] * (decay) +
                                          (net_g_params[k] * (1 - decay)))

    def setup_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        Args:
            input (dict): include the data itself and its metadata information.

        """
        self.real_img = paddle.fluid.dygraph.to_variable(input['A'])

    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
        pass

    def make_noise(self, batch, num_noise):
        if num_noise == 1:
            noises = paddle.randn([batch, self.num_style_feat])
        else:
            noises = []
            for _ in range(num_noise):
                noises.append(paddle.randn([batch, self.num_style_feat]))

        return noises

    def mixing_noise(self, batch, prob):
        if random.random() < prob:
            return self.make_noise(batch, 2)
        else:
            return [self.make_noise(batch, 1)]

    def train_iter(self, optimizers=None):
        current_iter = self.current_iter
        self.set_requires_grad(self.nets['disc'], True)
        optimizers['optimD'].clear_grad()
        batch = self.real_img.shape[0]
        noise = self.mixing_noise(batch, self.mixing_prob)

        fake_img, _ = self.nets['gen'](noise)
        self.visual_items['real_img'] = self.real_img
        self.visual_items['fake_img'] = fake_img
        fake_pred = self.nets['disc'](fake_img.detach())

        real_pred = self.nets['disc'](self.real_img)
        # wgan loss with softplus (logistic loss) for discriminator
        l_d_total = 0.
        l_d = self.gan_criterion(real_pred, True,
                                 is_disc=True) + self.gan_criterion(
                                     fake_pred, False, is_disc=True)
        self.losses['l_d'] = l_d
        # In wgan, real_score should be positive and fake_score should be
        # negative
        self.losses['real_score'] = real_pred.detach().mean()
        self.losses['fake_score'] = fake_pred.detach().mean()

        l_d_total += l_d

        if current_iter % self.disc_iters == 0:
            self.real_img.stop_gradient = False
            real_pred = self.nets['disc'](self.real_img)
            l_d_r1 = r1_penalty(real_pred, self.real_img)
            l_d_r1 = (self.r1_reg_weight / 2 * l_d_r1 * self.disc_iters +
                      0 * real_pred[0])

            self.losses['l_d_r1'] = l_d_r1.detach().mean()

            l_d_total += l_d_r1
        l_d_total.backward()

        optimizers['optimD'].step()

        self.set_requires_grad(self.nets['disc'], False)
        optimizers['optimG'].clear_grad()

        noise = self.mixing_noise(batch, self.mixing_prob)
        fake_img, _ = self.nets['gen'](noise)
        fake_pred = self.nets['disc'](fake_img)

        # wgan loss with softplus (non-saturating loss) for generator
        l_g_total = 0.
        l_g = self.gan_criterion(fake_pred, True, is_disc=False)
        self.losses['l_g'] = l_g

        l_g_total += l_g
        if current_iter % self.gen_iters == 0:
            path_batch_size = max(1, batch // self.path_batch_shrink)
            noise = self.mixing_noise(path_batch_size, self.mixing_prob)
            fake_img, latents = self.nets['gen'](noise, return_latents=True)
            l_g_path, path_lengths, self.mean_path_length = g_path_regularize(
                fake_img, latents, self.mean_path_length)

            l_g_path = (self.path_reg_weight * self.gen_iters * l_g_path +
                        0 * fake_img[0, 0, 0, 0])

            l_g_total += l_g_path
            self.losses['l_g_path'] = l_g_path.detach().mean()
            self.losses['path_length'] = path_lengths
        l_g_total.backward()
        optimizers['optimG'].step()

        # EMA
        self.model_ema(decay=0.5**(32 / (10 * 1000)))

        if self.current_iter % self.visual_iters:
            sample_z = [self.make_noise(1, 1)]
            sample, _ = self.nets['gen_ema'](sample_z)
            self.visual_items['fake_img_ema'] = sample

        self.current_iter += 1