diff --git a/configs/stylegan_v2_256_ffhq.yaml b/configs/stylegan_v2_256_ffhq.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d87268c97d174e7c457230ad71e7cb87928c5984 --- /dev/null +++ b/configs/stylegan_v2_256_ffhq.yaml @@ -0,0 +1,71 @@ +total_iters: 800000 +output_dir: output_dir + +model: + name: StyleGAN2Model + generator: + name: StyleGANv2Generator + size: 256 + style_dim: 512 + n_mlp: 8 + discriminator: + name: StyleGANv2Discriminator + size: 256 + gan_criterion: + name: GANLoss + gan_mode: logistic + loss_weight: !!float 1 + # r1 regularization for discriminator + r1_reg_weight: 10. + # path length regularization for generator + path_batch_shrink: 2. + path_reg_weight: 2. + params: + gen_iters: 4 + disc_iters: 16 + +dataset: + train: + name: SingleDataset + dataroot: data/ffhq/images256x256/ + num_workers: 3 + batch_size: 3 + preprocess: + - name: LoadImageFromFile + key: A + - name: Transforms + input_keys: [A] + pipeline: + - name: RandomHorizontalFlip + - name: Transpose + - name: Normalize + mean: [127.5, 127.5, 127.5] + std: [127.5, 127.5, 127.5] + +lr_scheduler: + name: MultiStepDecay + learning_rate: 0.002 + milestones: [600000] + gamma: 0.5 + +optimizer: + optimG: + name: Adam + beta1: 0.0 + beta2: 0.792 + net_names: + - gen + optimD: + name: Adam + net_names: + - disc + beta1: 0.0 + beta2: 0.9317647058823529 + + +log_config: + interval: 50 + visiual_interval: 500 + +snapshot_config: + interval: 5000 diff --git a/ppgan/datasets/preprocess/transforms.py b/ppgan/datasets/preprocess/transforms.py index 6ff4196f177d7aca6dc3cd8bb4442a7771181bfe..45901932d96499cb5520f5bc3b5f4cb705162a0c 100644 --- a/ppgan/datasets/preprocess/transforms.py +++ b/ppgan/datasets/preprocess/transforms.py @@ -59,19 +59,21 @@ class Transforms(): data = tuple(data) for transform in self.transforms: data = transform(data) - if hasattr(transform, 'params') and isinstance( transform.params, dict): datas.update(transform.params) + if len(self.input_keys) > 1: + for i, k in enumerate(self.input_keys): + datas[k] = data[i] + else: + datas[k] = data + if self.output_keys is not None: for i, k in enumerate(self.output_keys): datas[k] = data[i] return datas - for i, k in enumerate(self.input_keys): - datas[k] = data[i] - return datas diff --git a/ppgan/datasets/single_dataset.py b/ppgan/datasets/single_dataset.py index ad67c440d93a4e8a50ff541b03471d6aafa724b1..98661567e905f1258c41ee3ac33a535bdd25278f 100644 --- a/ppgan/datasets/single_dataset.py +++ b/ppgan/datasets/single_dataset.py @@ -27,7 +27,7 @@ class SingleDataset(BaseDataset): dataroot (str): Directory of dataset. preprocess (list[dict]): A sequence of data preprocess config. """ - super(SingleDataset).__init__(self, preprocess) + super(SingleDataset, self).__init__(preprocess) self.dataroot = dataroot self.data_infos = self.prepare_data_infos() diff --git a/ppgan/engine/trainer.py b/ppgan/engine/trainer.py index 18c27720d63d9b48245ff1fc14cbaefcb65b0497..d3b303c9da7951a81c3665f49efad16da7a1193c 100644 --- a/ppgan/engine/trainer.py +++ b/ppgan/engine/trainer.py @@ -123,6 +123,8 @@ class Trainer: self.batch_id = 0 self.global_steps = 0 self.weight_interval = cfg.snapshot_config.interval + if self.by_epoch: + self.weight_interval *= self.iters_per_epoch self.log_interval = cfg.log_config.interval self.visual_interval = cfg.log_config.visiual_interval if self.by_epoch: @@ -143,6 +145,17 @@ class Trainer: for net_name, net in self.model.nets.items(): self.model.nets[net_name] = paddle.DataParallel(net, strategy) + def learning_rate_scheduler_step(self): + if isinstance(self.model.lr_scheduler, dict): + for lr_scheduler in self.model.lr_scheduler.values(): + lr_scheduler.step() + elif isinstance(self.model.lr_scheduler, + paddle.optimizer.lr.LRScheduler): + self.model.lr_scheduler.step() + else: + raise ValueError( + 'lr schedulter must be a dict or an instance of LRScheduler') + def train(self): reader_cost_averager = TimeAverager() batch_cost_averager = TimeAverager() @@ -179,7 +192,7 @@ class Trainer: if self.current_iter % self.visual_interval == 0: self.visual('visual_train') - self.model.lr_scheduler.step() + self.learning_rate_scheduler_step() if self.validate_interval > -1 and self.current_iter % self.validate_interval == 0: self.test() diff --git a/ppgan/models/__init__.py b/ppgan/models/__init__.py index bb1f27897ef27c2c400c22a433107049042e2506..bb10a7b78fab5d10ba6a7285817654ffe15e8fbd 100644 --- a/ppgan/models/__init__.py +++ b/ppgan/models/__init__.py @@ -22,5 +22,6 @@ from .esrgan_model import ESRGAN from .ugatit_model import UGATITModel from .dc_gan_model import DCGANModel from .animeganv2_model import AnimeGANV2Model, AnimeGANV2PreTrainModel +from .styleganv2_model import StyleGAN2Model from .wav2lip_model import Wav2LipModel from .wav2lip_hq_model import Wav2LipModelHq diff --git a/ppgan/models/discriminators/discriminator_styleganv2.py b/ppgan/models/discriminators/discriminator_styleganv2.py index a06e1f60927d02de8021343daf345a4bd78b66fa..038d39ab5f24374c9d1ec06675a44d8f94e80c32 100644 --- a/ppgan/models/discriminators/discriminator_styleganv2.py +++ b/ppgan/models/discriminators/discriminator_styleganv2.py @@ -35,22 +35,22 @@ class ConvLayer(nn.Sequential): activate=True, ): layers = [] - + if downsample: factor = 2 p = (len(blur_kernel) - factor) + (kernel_size - 1) pad0 = (p + 1) // 2 pad1 = p // 2 - + layers.append(Upfirdn2dBlur(blur_kernel, pad=(pad0, pad1))) - + stride = 2 self.padding = 0 - + else: stride = 1 self.padding = kernel_size // 2 - + layers.append( EqualConv2D( in_channel, @@ -59,41 +59,58 @@ class ConvLayer(nn.Sequential): padding=self.padding, stride=stride, bias=bias and not activate, - ) - ) - + )) + if activate: layers.append(FusedLeakyReLU(out_channel, bias=bias)) - + super().__init__(*layers) - - + + class ResBlock(nn.Layer): def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]): super().__init__() - + self.conv1 = ConvLayer(in_channel, in_channel, 3) self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True) - - self.skip = ConvLayer( - in_channel, out_channel, 1, downsample=True, activate=False, bias=False - ) - + + self.skip = ConvLayer(in_channel, + out_channel, + 1, + downsample=True, + activate=False, + bias=False) + def forward(self, input): out = self.conv1(input) out = self.conv2(out) - + skip = self.skip(input) out = (out + skip) / math.sqrt(2) - + return out - - + + +# temporally solve pow double grad problem +def var(x, axis=None, unbiased=True, keepdim=False, name=None): + + u = paddle.mean(x, axis, True, name) + out = paddle.sum((x - u) * (x - u), axis, keepdim=keepdim, name=name) + + n = paddle.cast(paddle.numel(x), x.dtype) \ + / paddle.cast(paddle.numel(out), x.dtype) + if unbiased: + one_const = paddle.ones([1], x.dtype) + n = paddle.where(n > one_const, n - 1., one_const) + out /= n + return out + + @DISCRIMINATORS.register() class StyleGANv2Discriminator(nn.Layer): def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]): super().__init__() - + channels = { 4: 512, 8: 512, @@ -105,47 +122,48 @@ class StyleGANv2Discriminator(nn.Layer): 512: 32 * channel_multiplier, 1024: 16 * channel_multiplier, } - + convs = [ConvLayer(3, channels[size], 1)] - + log_size = int(math.log(size, 2)) - + in_channel = channels[size] - + for i in range(log_size, 2, -1): - out_channel = channels[2 ** (i - 1)] - + out_channel = channels[2**(i - 1)] + convs.append(ResBlock(in_channel, out_channel, blur_kernel)) - + in_channel = out_channel - + self.convs = nn.Sequential(*convs) - + self.stddev_group = 4 self.stddev_feat = 1 - + self.final_conv = ConvLayer(in_channel + 1, channels[4], 3) self.final_linear = nn.Sequential( - EqualLinear(channels[4] * 4 * 4, channels[4], activation="fused_lrelu"), + EqualLinear(channels[4] * 4 * 4, + channels[4], + activation="fused_lrelu"), EqualLinear(channels[4], 1), ) - + def forward(self, input): out = self.convs(input) - + batch, channel, height, width = out.shape group = min(batch, self.stddev_group) - stddev = out.reshape(( - group, -1, self.stddev_feat, channel // self.stddev_feat, height, width - )) - stddev = paddle.sqrt(stddev.var(0, unbiased=False) + 1e-8) + stddev = out.reshape((group, -1, self.stddev_feat, + channel // self.stddev_feat, height, width)) + stddev = paddle.sqrt(var(stddev, 0, unbiased=False) + 1e-8) stddev = stddev.mean([2, 3, 4], keepdim=True).squeeze(2) stddev = stddev.tile((group, 1, height, width)) out = paddle.concat([out, stddev], 1) - + out = self.final_conv(out) - + out = out.reshape((batch, -1)) out = self.final_linear(out) - + return out diff --git a/ppgan/models/generators/generator_styleganv2.py b/ppgan/models/generators/generator_styleganv2.py index 0c0ccbaaf4cfa969792523de6ff4439876e41c09..cabfe340e2dd38e74e1c24a01b4d8a979b24388d 100644 --- a/ppgan/models/generators/generator_styleganv2.py +++ b/ppgan/models/generators/generator_styleganv2.py @@ -27,11 +27,12 @@ from ...modules.upfirdn2d import Upfirdn2dUpsample, Upfirdn2dBlur class PixelNorm(nn.Layer): def __init__(self): super().__init__() - + def forward(self, input): - return input * paddle.rsqrt(paddle.mean(input ** 2, 1, keepdim=True) + 1e-8) - - + return input * paddle.rsqrt( + paddle.mean(input * input, 1, keepdim=True) + 1e-8) + + class ModulatedConv2D(nn.Layer): def __init__( self, @@ -45,75 +46,78 @@ class ModulatedConv2D(nn.Layer): blur_kernel=[1, 3, 3, 1], ): super().__init__() - + self.eps = 1e-8 self.kernel_size = kernel_size self.in_channel = in_channel self.out_channel = out_channel self.upsample = upsample self.downsample = downsample - + if upsample: factor = 2 p = (len(blur_kernel) - factor) - (kernel_size - 1) pad0 = (p + 1) // 2 + factor - 1 pad1 = p // 2 + 1 - - self.blur = Upfirdn2dBlur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor) - + + self.blur = Upfirdn2dBlur(blur_kernel, + pad=(pad0, pad1), + upsample_factor=factor) + if downsample: factor = 2 p = (len(blur_kernel) - factor) + (kernel_size - 1) pad0 = (p + 1) // 2 pad1 = p // 2 - + self.blur = Upfirdn2dBlur(blur_kernel, pad=(pad0, pad1)) - - fan_in = in_channel * kernel_size ** 2 + + fan_in = in_channel * (kernel_size * kernel_size) self.scale = 1 / math.sqrt(fan_in) self.padding = kernel_size // 2 - + self.weight = self.create_parameter( - (1, out_channel, in_channel, kernel_size, kernel_size), default_initializer=nn.initializer.Normal() - ) - + (1, out_channel, in_channel, kernel_size, kernel_size), + default_initializer=nn.initializer.Normal()) + self.modulation = EqualLinear(style_dim, in_channel, bias_init=1) - + self.demodulate = demodulate - + def __repr__(self): return ( f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, " - f"upsample={self.upsample}, downsample={self.downsample})" - ) - + f"upsample={self.upsample}, downsample={self.downsample})") + def forward(self, input, style): batch, in_channel, height, width = input.shape - + style = self.modulation(style).reshape((batch, 1, in_channel, 1, 1)) weight = self.scale * self.weight * style - + if self.demodulate: - demod = paddle.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8) + demod = paddle.rsqrt((weight * weight).sum([2, 3, 4]) + 1e-8) weight = weight * demod.reshape((batch, self.out_channel, 1, 1, 1)) - - weight = weight.reshape(( - batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size - )) - + + weight = weight.reshape((batch * self.out_channel, in_channel, + self.kernel_size, self.kernel_size)) + if self.upsample: input = input.reshape((1, batch * in_channel, height, width)) - weight = weight.reshape(( - batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size - )) - weight = weight.transpose((0, 2, 1, 3, 4)).reshape(( - batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size - )) - out = F.conv2d_transpose(input, weight, padding=0, stride=2, groups=batch) + weight = weight.reshape((batch, self.out_channel, in_channel, + self.kernel_size, self.kernel_size)) + weight = weight.transpose((0, 2, 1, 3, 4)).reshape( + (batch * in_channel, self.out_channel, self.kernel_size, + self.kernel_size)) + out = F.conv2d_transpose(input, + weight, + padding=0, + stride=2, + groups=batch) _, _, height, width = out.shape out = out.reshape((batch, self.out_channel, height, width)) out = self.blur(out) - + elif self.downsample: input = self.blur(input) _, _, height, width = input.shape @@ -121,43 +125,46 @@ class ModulatedConv2D(nn.Layer): out = F.conv2d(input, weight, padding=0, stride=2, groups=batch) _, _, height, width = out.shape out = out.reshape((batch, self.out_channel, height, width)) - + else: input = input.reshape((1, batch * in_channel, height, width)) out = F.conv2d(input, weight, padding=self.padding, groups=batch) _, _, height, width = out.shape out = out.reshape((batch, self.out_channel, height, width)) - + return out - - + + class NoiseInjection(nn.Layer): def __init__(self): super().__init__() - - self.weight = self.create_parameter((1,), default_initializer=nn.initializer.Constant(0.0)) - + + self.weight = self.create_parameter( + (1, ), default_initializer=nn.initializer.Constant(0.0)) + def forward(self, image, noise=None): if noise is None: batch, _, height, width = image.shape noise = paddle.randn((batch, 1, height, width)) - + return image + self.weight * noise - - + + class ConstantInput(nn.Layer): def __init__(self, channel, size=4): super().__init__() - - self.input = self.create_parameter((1, channel, size, size), default_initializer=nn.initializer.Normal()) - + + self.input = self.create_parameter( + (1, channel, size, size), + default_initializer=nn.initializer.Normal()) + def forward(self, input): batch = input.shape[0] out = self.input.tile((batch, 1, 1, 1)) - + return out - - + + class StyledConv(nn.Layer): def __init__( self, @@ -170,7 +177,7 @@ class StyledConv(nn.Layer): demodulate=True, ): super().__init__() - + self.conv = ModulatedConv2D( in_channel, out_channel, @@ -180,40 +187,49 @@ class StyledConv(nn.Layer): blur_kernel=blur_kernel, demodulate=demodulate, ) - + self.noise = NoiseInjection() self.activate = FusedLeakyReLU(out_channel) - + def forward(self, input, style, noise=None): out = self.conv(input, style) out = self.noise(out, noise=noise) out = self.activate(out) - + return out - - + + class ToRGB(nn.Layer): - def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]): + def __init__(self, + in_channel, + style_dim, + upsample=True, + blur_kernel=[1, 3, 3, 1]): super().__init__() - + if upsample: self.upsample = Upfirdn2dUpsample(blur_kernel) - - self.conv = ModulatedConv2D(in_channel, 3, 1, style_dim, demodulate=False) - self.bias = self.create_parameter((1, 3, 1, 1), nn.initializer.Constant(0.0)) - + + self.conv = ModulatedConv2D(in_channel, + 3, + 1, + style_dim, + demodulate=False) + self.bias = self.create_parameter((1, 3, 1, 1), + nn.initializer.Constant(0.0)) + def forward(self, input, style, skip=None): out = self.conv(input, style) out = out + self.bias - + if skip is not None: skip = self.upsample(skip) - + out = out + skip - + return out - - + + @GENERATORS.register() class StyleGANv2Generator(nn.Layer): def __init__( @@ -226,22 +242,22 @@ class StyleGANv2Generator(nn.Layer): lr_mlp=0.01, ): super().__init__() - + self.size = size - + self.style_dim = style_dim - + layers = [PixelNorm()] - + for i in range(n_mlp): layers.append( - EqualLinear( - style_dim, style_dim, lr_mul=lr_mlp, activation="fused_lrelu" - ) - ) - + EqualLinear(style_dim, + style_dim, + lr_mul=lr_mlp, + activation="fused_lrelu")) + self.style = nn.Sequential(*layers) - + self.channels = { 4: 512, 8: 512, @@ -253,31 +269,34 @@ class StyleGANv2Generator(nn.Layer): 512: 32 * channel_multiplier, 1024: 16 * channel_multiplier, } - + self.input = ConstantInput(self.channels[4]) - self.conv1 = StyledConv( - self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel - ) + self.conv1 = StyledConv(self.channels[4], + self.channels[4], + 3, + style_dim, + blur_kernel=blur_kernel) self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False) - + self.log_size = int(math.log(size, 2)) self.num_layers = (self.log_size - 2) * 2 + 1 - + self.convs = nn.LayerList() self.upsamples = nn.LayerList() self.to_rgbs = nn.LayerList() self.noises = nn.Layer() - + in_channel = self.channels[4] - + for layer_idx in range(self.num_layers): res = (layer_idx + 5) // 2 - shape = [1, 1, 2 ** res, 2 ** res] - self.noises.register_buffer(f"noise_{layer_idx}", paddle.randn(shape)) - + shape = [1, 1, 2**res, 2**res] + self.noises.register_buffer(f"noise_{layer_idx}", + paddle.randn(shape)) + for i in range(3, self.log_size + 1): - out_channel = self.channels[2 ** i] - + out_channel = self.channels[2**i] + self.convs.append( StyledConv( in_channel, @@ -286,41 +305,39 @@ class StyleGANv2Generator(nn.Layer): style_dim, upsample=True, blur_kernel=blur_kernel, - ) - ) - + )) + self.convs.append( - StyledConv( - out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel - ) - ) - + StyledConv(out_channel, + out_channel, + 3, + style_dim, + blur_kernel=blur_kernel)) + self.to_rgbs.append(ToRGB(out_channel, style_dim)) - + in_channel = out_channel - + self.n_latent = self.log_size * 2 - 2 - + def make_noise(self): - noises = [paddle.randn((1, 1, 2 ** 2, 2 ** 2))] - + noises = [paddle.randn((1, 1, 2**2, 2**2))] + for i in range(3, self.log_size + 1): for _ in range(2): - noises.append(paddle.randn((1, 1, 2 ** i, 2 ** i))) - + noises.append(paddle.randn((1, 1, 2**i, 2**i))) + return noises - + def mean_latent(self, n_latent): - latent_in = paddle.randn(( - n_latent, self.style_dim - )) + latent_in = paddle.randn((n_latent, self.style_dim)) latent = self.style(latent_in).mean(0, keepdim=True) - + return latent - + def get_latent(self, input): return self.style(input) - + def forward( self, styles, @@ -334,62 +351,65 @@ class StyleGANv2Generator(nn.Layer): ): if not input_is_latent: styles = [self.style(s) for s in styles] - + if noise is None: if randomize_noise: noise = [None] * self.num_layers else: noise = [ - getattr(self.noises, f"noise_{i}") for i in range(self.num_layers) + getattr(self.noises, f"noise_{i}") + for i in range(self.num_layers) ] - + if truncation < 1: style_t = [] - + for style in styles: - style_t.append( - truncation_latent + truncation * (style - truncation_latent) - ) - + style_t.append(truncation_latent + truncation * + (style - truncation_latent)) + styles = style_t - + if len(styles) < 2: inject_index = self.n_latent - + if styles[0].ndim < 3: latent = styles[0].unsqueeze(1).tile((1, inject_index, 1)) - + else: latent = styles[0] - + else: if inject_index is None: inject_index = random.randint(1, self.n_latent - 1) - + latent = styles[0].unsqueeze(1).tile((1, inject_index, 1)) - latent2 = styles[1].unsqueeze(1).tile((1, self.n_latent - inject_index, 1)) - + latent2 = styles[1].unsqueeze(1).tile( + (1, self.n_latent - inject_index, 1)) + latent = paddle.concat([latent, latent2], 1) - + out = self.input(latent) out = self.conv1(out, latent[:, 0], noise=noise[0]) - + skip = self.to_rgb1(out, latent[:, 1]) - + i = 1 - for conv1, conv2, noise1, noise2, to_rgb in zip( - self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs - ): + for conv1, conv2, noise1, noise2, to_rgb in zip(self.convs[::2], + self.convs[1::2], + noise[1::2], + noise[2::2], + self.to_rgbs): out = conv1(out, latent[:, i], noise=noise1) out = conv2(out, latent[:, i + 1], noise=noise2) skip = to_rgb(out, latent[:, i + 2], skip) - + i += 2 - + image = skip - + if return_latents: return image, latent - + else: return image, None diff --git a/ppgan/models/styleganv2_model.py b/ppgan/models/styleganv2_model.py new file mode 100644 index 0000000000000000000000000000000000000000..1f10ed0393d9eacce79b1e03350bca7dc1da96eb --- /dev/null +++ b/ppgan/models/styleganv2_model.py @@ -0,0 +1,282 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +import random +import paddle +import paddle.nn as nn +from .base_model import BaseModel + +from .builder import MODELS +from .criterions import build_criterion +from .generators.builder import build_generator +from .discriminators.builder import build_discriminator +from ..solver import build_lr_scheduler, build_optimizer + + +def r1_penalty(real_pred, real_img): + """ + R1 regularization for discriminator. The core idea is to + penalize the gradient on real data alone: when the + generator distribution produces the true data distribution + and the discriminator is equal to 0 on the data manifold, the + gradient penalty ensures that the discriminator cannot create + a non-zero gradient orthogonal to the data manifold without + suffering a loss in the GAN game. + + Ref: + Eq. 9 in Which training methods for GANs do actually converge. + """ + + grad_real = paddle.grad(outputs=real_pred.sum(), + inputs=real_img, + create_graph=True)[0] + grad_penalty = (grad_real * grad_real).reshape([grad_real.shape[0], + -1]).sum(1).mean() + return grad_penalty + + +def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01): + noise = paddle.randn(fake_img.shape) / math.sqrt( + fake_img.shape[2] * fake_img.shape[3]) + grad = paddle.grad(outputs=(fake_img * noise).sum(), + inputs=latents, + create_graph=True)[0] + path_lengths = paddle.sqrt((grad * grad).sum(2).mean(1)) + + path_mean = mean_path_length + decay * (path_lengths.mean() - + mean_path_length) + + path_penalty = ((path_lengths - path_mean) * + (path_lengths - path_mean)).mean() + + return path_penalty, path_lengths.detach().mean(), path_mean.detach() + + +@MODELS.register() +class StyleGAN2Model(BaseModel): + """ + This class implements the StyleGANV2 model, for learning image-to-image translation without paired data. + + StyleGAN2 paper: https://arxiv.org/pdf/1912.04958.pdf + """ + def __init__(self, + generator, + discriminator=None, + gan_criterion=None, + num_style_feat=512, + mixing_prob=0.9, + r1_reg_weight=10., + path_reg_weight=2., + path_batch_shrink=2., + params=None): + """Initialize the CycleGAN class. + + Args: + generator (dict): config of generator. + discriminator (dict): config of discriminator. + gan_criterion (dict): config of gan criterion. + """ + super(StyleGAN2Model, self).__init__(params) + self.gen_iters = 4 if self.params is None else self.params.get( + 'gen_iters', 4) + self.disc_iters = 16 if self.params is None else self.params.get( + 'disc_iters', 16) + self.disc_start_iters = (0 if self.params is None else self.params.get( + 'disc_start_iters', 0)) + + self.visual_iters = (500 if self.params is None else self.params.get( + 'visual_iters', 500)) + + self.mixing_prob = mixing_prob + self.num_style_feat = num_style_feat + self.r1_reg_weight = r1_reg_weight + + self.path_reg_weight = path_reg_weight + self.path_batch_shrink = path_batch_shrink + self.mean_path_length = 0 + + self.nets['gen'] = build_generator(generator) + + # define discriminators + if discriminator: + self.nets['disc'] = build_discriminator(discriminator) + + self.nets['gen_ema'] = build_generator(generator) + self.model_ema(0) + + self.nets['gen'].train() + self.nets['gen_ema'].eval() + self.nets['disc'].train() + self.current_iter = 1 + + # define loss functions + if gan_criterion: + self.gan_criterion = build_criterion(gan_criterion) + + def setup_lr_schedulers(self, cfg): + self.lr_scheduler = dict() + gen_cfg = cfg.copy() + net_g_reg_ratio = self.gen_iters / (self.gen_iters + 1) + gen_cfg['learning_rate'] = cfg['learning_rate'] * net_g_reg_ratio + self.lr_scheduler['gen'] = build_lr_scheduler(gen_cfg) + + disc_cfg = cfg.copy() + net_d_reg_ratio = self.disc_iters / (self.disc_iters + 1) + disc_cfg['learning_rate'] = cfg['learning_rate'] * net_d_reg_ratio + self.lr_scheduler['disc'] = build_lr_scheduler(disc_cfg) + return self.lr_scheduler + + def setup_optimizers(self, lr, cfg): + for opt_name, opt_cfg in cfg.items(): + if opt_name == 'optimG': + _lr = lr['gen'] + elif opt_name == 'optimD': + _lr = lr['disc'] + else: + raise ValueError("opt name must be in ['optimG', optimD]") + + cfg_ = opt_cfg.copy() + net_names = cfg_.pop('net_names') + parameters = [] + for net_name in net_names: + parameters += self.nets[net_name].parameters() + self.optimizers[opt_name] = build_optimizer(cfg_, _lr, parameters) + + return self.optimizers + + def get_bare_model(self, net): + """Get bare model, especially under wrapping with DataParallel. + """ + if isinstance(net, (paddle.DataParallel)): + net = net._layers + return net + + def model_ema(self, decay=0.999): + net_g = self.get_bare_model(self.nets['gen']) + net_g_params = dict(net_g.named_parameters()) + + neg_g_ema = self.get_bare_model(self.nets['gen_ema']) + net_g_ema_params = dict(neg_g_ema.named_parameters()) + + for k in net_g_ema_params.keys(): + net_g_ema_params[k].set_value(net_g_ema_params[k] * (decay) + + (net_g_params[k] * (1 - decay))) + + def setup_input(self, input): + """Unpack input data from the dataloader and perform necessary pre-processing steps. + + Args: + input (dict): include the data itself and its metadata information. + + """ + self.real_img = paddle.fluid.dygraph.to_variable(input['A']) + + def forward(self): + """Run forward pass; called by both functions and .""" + pass + + def make_noise(self, batch, num_noise): + if num_noise == 1: + noises = paddle.randn([batch, self.num_style_feat]) + else: + noises = [] + for _ in range(num_noise): + noises.append(paddle.randn([batch, self.num_style_feat])) + + return noises + + def mixing_noise(self, batch, prob): + if random.random() < prob: + return self.make_noise(batch, 2) + else: + return [self.make_noise(batch, 1)] + + def train_iter(self, optimizers=None): + current_iter = self.current_iter + self.set_requires_grad(self.nets['disc'], True) + optimizers['optimD'].clear_grad() + batch = self.real_img.shape[0] + noise = self.mixing_noise(batch, self.mixing_prob) + + fake_img, _ = self.nets['gen'](noise) + self.visual_items['real_img'] = self.real_img + self.visual_items['fake_img'] = fake_img + fake_pred = self.nets['disc'](fake_img.detach()) + + real_pred = self.nets['disc'](self.real_img) + # wgan loss with softplus (logistic loss) for discriminator + l_d_total = 0. + l_d = self.gan_criterion(real_pred, True, + is_disc=True) + self.gan_criterion( + fake_pred, False, is_disc=True) + self.losses['l_d'] = l_d + # In wgan, real_score should be positive and fake_score should be + # negative + self.losses['real_score'] = real_pred.detach().mean() + self.losses['fake_score'] = fake_pred.detach().mean() + + l_d_total += l_d + + if current_iter % self.disc_iters == 0: + self.real_img.stop_gradient = False + real_pred = self.nets['disc'](self.real_img) + l_d_r1 = r1_penalty(real_pred, self.real_img) + l_d_r1 = (self.r1_reg_weight / 2 * l_d_r1 * self.disc_iters + + 0 * real_pred[0]) + + self.losses['l_d_r1'] = l_d_r1.detach().mean() + + l_d_total += l_d_r1 + l_d_total.backward() + + optimizers['optimD'].step() + + self.set_requires_grad(self.nets['disc'], False) + optimizers['optimG'].clear_grad() + + noise = self.mixing_noise(batch, self.mixing_prob) + fake_img, _ = self.nets['gen'](noise) + fake_pred = self.nets['disc'](fake_img) + + # wgan loss with softplus (non-saturating loss) for generator + l_g_total = 0. + l_g = self.gan_criterion(fake_pred, True, is_disc=False) + self.losses['l_g'] = l_g + + l_g_total += l_g + if current_iter % self.gen_iters == 0: + path_batch_size = max(1, batch // self.path_batch_shrink) + noise = self.mixing_noise(path_batch_size, self.mixing_prob) + fake_img, latents = self.nets['gen'](noise, return_latents=True) + l_g_path, path_lengths, self.mean_path_length = g_path_regularize( + fake_img, latents, self.mean_path_length) + + l_g_path = (self.path_reg_weight * self.gen_iters * l_g_path + + 0 * fake_img[0, 0, 0, 0]) + + l_g_total += l_g_path + self.losses['l_g_path'] = l_g_path.detach().mean() + self.losses['path_length'] = path_lengths + l_g_total.backward() + optimizers['optimG'].step() + + # EMA + self.model_ema(decay=0.5**(32 / (10 * 1000))) + + if self.current_iter % self.visual_iters: + sample_z = [self.make_noise(1, 1)] + sample, _ = self.nets['gen_ema'](sample_z) + self.visual_items['fake_img_ema'] = sample + + self.current_iter += 1 diff --git a/ppgan/modules/equalized.py b/ppgan/modules/equalized.py index 7280ab0e212f7309f2125e19e83cca59b096f31e..7d2eef17ba6febb1c12a0ff2d5e685191b7d5e2f 100644 --- a/ppgan/modules/equalized.py +++ b/ppgan/modules/equalized.py @@ -24,25 +24,30 @@ class EqualConv2D(nn.Layer): """This convolutional layer class stabilizes the learning rate changes of its parameters. Equalizing learning rate keeps the weights in the network at a similar scale during training. """ - def __init__( - self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True - ): + def __init__(self, + in_channel, + out_channel, + kernel_size, + stride=1, + padding=0, + bias=True): super().__init__() - + self.weight = self.create_parameter( - (out_channel, in_channel, kernel_size, kernel_size), default_initializer=nn.initializer.Normal() - ) - self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2) - + (out_channel, in_channel, kernel_size, kernel_size), + default_initializer=nn.initializer.Normal()) + self.scale = 1 / math.sqrt(in_channel * (kernel_size * kernel_size)) + self.stride = stride self.padding = padding - + if bias: - self.bias = self.create_parameter((out_channel,), nn.initializer.Constant(0.0)) - + self.bias = self.create_parameter((out_channel, ), + nn.initializer.Constant(0.0)) + else: self.bias = None - + def forward(self, input): out = F.conv2d( input, @@ -51,51 +56,57 @@ class EqualConv2D(nn.Layer): stride=self.stride, padding=self.padding, ) - + return out - + def __repr__(self): return ( f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]}," f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})" ) - - + + class EqualLinear(nn.Layer): """This linear layer class stabilizes the learning rate changes of its parameters. Equalizing learning rate keeps the weights in the network at a similar scale during training. """ - def __init__( - self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None - ): + def __init__(self, + in_dim, + out_dim, + bias=True, + bias_init=0, + lr_mul=1, + activation=None): super().__init__() - - self.weight = self.create_parameter((in_dim, out_dim), default_initializer=nn.initializer.Normal()) - self.weight[:] = (self.weight / lr_mul).detach() - + + self.weight = self.create_parameter( + (in_dim, out_dim), default_initializer=nn.initializer.Normal()) + self.weight.set_value((self.weight / lr_mul)) + if bias: - self.bias = self.create_parameter((out_dim,), nn.initializer.Constant(bias_init)) - + self.bias = self.create_parameter( + (out_dim, ), nn.initializer.Constant(bias_init)) + else: self.bias = None - + self.activation = activation - + self.scale = (1 / math.sqrt(in_dim)) * lr_mul self.lr_mul = lr_mul - + def forward(self, input): if self.activation: out = F.linear(input, self.weight * self.scale) out = fused_leaky_relu(out, self.bias * self.lr_mul) - + else: - out = F.linear( - input, self.weight * self.scale, bias=self.bias * self.lr_mul - ) - + out = F.linear(input, + self.weight * self.scale, + bias=self.bias * self.lr_mul) + return out - + def __repr__(self): return ( f"{self.__class__.__name__}({self.weight.shape[0]}, {self.weight.shape[1]})" diff --git a/ppgan/modules/upfirdn2d.py b/ppgan/modules/upfirdn2d.py index 856378a62dd14c613787fd9ecead77d036b27467..ac34a889b279a1cf439b866192807ffad5f04571 100644 --- a/ppgan/modules/upfirdn2d.py +++ b/ppgan/modules/upfirdn2d.py @@ -15,37 +15,35 @@ import paddle import paddle.nn as nn import paddle.nn.functional as F - - -def upfirdn2d_native( - input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 -): + + +def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, + pad_y0, pad_y1): _, channel, in_h, in_w = input.shape input = input.reshape((-1, in_h, in_w, 1)) - + _, in_h, in_w, minor = input.shape kernel_h, kernel_w = kernel.shape - + out = input.reshape((-1, in_h, 1, in_w, 1, minor)) - out = out.transpose((0,1,3,5,2,4)) - out = out.reshape((-1,1,1,1)) + out = out.transpose((0, 1, 3, 5, 2, 4)) + out = out.reshape((-1, 1, 1, 1)) out = F.pad(out, [0, up_x - 1, 0, up_y - 1]) out = out.reshape((-1, in_h, in_w, minor, up_y, up_x)) - out = out.transpose((0,3,1,4,2,5)) + out = out.transpose((0, 3, 1, 4, 2, 5)) out = out.reshape((-1, minor, in_h * up_y, in_w * up_x)) - + out = F.pad( - out, [max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] - ) - out = out[ - :,:, - max(-pad_y0, 0) : out.shape[2] - max(-pad_y1, 0), - max(-pad_x0, 0) : out.shape[3] - max(-pad_x1, 0), - ] - - out = out.reshape(( - [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] - )) + out, [max(pad_x0, 0), + max(pad_x1, 0), + max(pad_y0, 0), + max(pad_y1, 0)]) + out = out[:, :, + max(-pad_y0, 0):out.shape[2] - max(-pad_y1, 0), + max(-pad_x0, 0):out.shape[3] - max(-pad_x1, 0), ] + + out = out.reshape( + ([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])) w = paddle.flip(kernel, [0, 1]).reshape((1, 1, kernel_h, kernel_w)) out = F.conv2d(out, w) out = out.reshape(( @@ -56,88 +54,95 @@ def upfirdn2d_native( )) out = out.transpose((0, 2, 3, 1)) out = out[:, ::down_y, ::down_x, :] - + out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 - + return out.reshape((-1, channel, out_h, out_w)) - + def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): - out = upfirdn2d_native( - input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1] - ) - + out = upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], + pad[0], pad[1]) + return out def make_kernel(k): k = paddle.to_tensor(k, dtype='float32') - + if k.ndim == 1: k = k.unsqueeze(0) * k.unsqueeze(1) - + k /= k.sum() - + return k - - + + class Upfirdn2dUpsample(nn.Layer): def __init__(self, kernel, factor=2): super().__init__() - + self.factor = factor - kernel = make_kernel(kernel) * (factor ** 2) + kernel = make_kernel(kernel) * (factor * factor) self.register_buffer("kernel", kernel) - + p = kernel.shape[0] - factor - + pad0 = (p + 1) // 2 + factor - 1 pad1 = p // 2 - + self.pad = (pad0, pad1) - + def forward(self, input): - out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad) - + out = upfirdn2d(input, + self.kernel, + up=self.factor, + down=1, + pad=self.pad) + return out - - + + class Upfirdn2dDownsample(nn.Layer): def __init__(self, kernel, factor=2): super().__init__() - + self.factor = factor kernel = make_kernel(kernel) self.register_buffer("kernel", kernel) - + p = kernel.shape[0] - factor - + pad0 = (p + 1) // 2 pad1 = p // 2 - + self.pad = (pad0, pad1) - + def forward(self, input): - out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad) - + out = upfirdn2d(input, + self.kernel, + up=1, + down=self.factor, + pad=self.pad) + return out - - + + class Upfirdn2dBlur(nn.Layer): def __init__(self, kernel, pad, upsample_factor=1): super().__init__() - + kernel = make_kernel(kernel) - + if upsample_factor > 1: - kernel = kernel * (upsample_factor ** 2) - - self.register_buffer("kernel", kernel) - + kernel = kernel * (upsample_factor * upsample_factor) + + self.register_buffer("kernel", kernel, persistable=False) + self.pad = pad - + def forward(self, input): out = upfirdn2d(input, self.kernel, pad=self.pad) - + return out diff --git a/ppgan/utils/audio.py b/ppgan/utils/audio.py index 28634186cbf98eb68e721a437a3048ca591e72c6..9cf1e74208e9552062bf0922b26e4b6433e3b5b2 100644 --- a/ppgan/utils/audio.py +++ b/ppgan/utils/audio.py @@ -1,14 +1,14 @@ -import librosa -import librosa.filters import numpy as np from scipy import signal from scipy.io import wavfile +from paddle.utils import try_import from .audio_config import get_audio_config audio_config = get_audio_config() def load_wav(path, sr): + librosa = try_import('librosa') return librosa.core.load(path, sr=sr)[0] @@ -19,6 +19,7 @@ def save_wav(wav, path, sr): def save_wavenet_wav(wav, path, sr): + librosa = try_import('librosa') librosa.output.write_wav(path, wav, sr=sr) @@ -75,6 +76,7 @@ def _stft(y): if audio_config.use_lws: return _lws_processor(audio_config).stft(y).T else: + librosa = try_import('librosa') return librosa.stft(y=y, n_fft=audio_config.n_fft, hop_length=get_hop_size(), @@ -123,6 +125,7 @@ def _linear_to_mel(spectogram): def _build_mel_basis(): assert audio_config.fmax <= audio_config.sample_rate // 2 + librosa = try_import('librosa') return librosa.filters.mel(audio_config.sample_rate, audio_config.n_fft, n_mels=audio_config.num_mels,