Unverified commit 530a6a8c, authored by LielinJiang, committed by GitHub

Add train code of stylegan2 (#149)

* add stylegan model
Parent commit: e13e1c18
# Training configuration for StyleGAN2 on FFHQ 256x256.
total_iters: 800000
output_dir: output_dir

model:
  name: StyleGAN2Model
  generator:
    name: StyleGANv2Generator
    size: 256            # output resolution (must be a power of 2)
    style_dim: 512       # dimensionality of the w latent space
    n_mlp: 8             # depth of the z -> w mapping network
  discriminator:
    name: StyleGANv2Discriminator
    size: 256            # must match generator size
  gan_criterion:
    name: GANLoss
    gan_mode: logistic   # non-saturating logistic GAN loss
    loss_weight: !!float 1
  # r1 regularization for discriminator
  r1_reg_weight: 10.
  # path length regularization for generator
  path_batch_shrink: 2.
  path_reg_weight: 2.
  params:
    # lazy-regularization intervals: apply path-length reg every 4
    # generator steps and R1 reg every 16 discriminator steps
    gen_iters: 4
    disc_iters: 16

dataset:
  train:
    name: SingleDataset
    dataroot: data/ffhq/images256x256/
    num_workers: 3
    batch_size: 3
    preprocess:
      - name: LoadImageFromFile
        key: A
      - name: Transforms
        input_keys: [A]
        pipeline:
          - name: RandomHorizontalFlip
          - name: Transpose
          - name: Normalize
            # maps uint8 [0, 255] images to [-1, 1]
            mean: [127.5, 127.5, 127.5]
            std: [127.5, 127.5, 127.5]

lr_scheduler:
  name: MultiStepDecay
  learning_rate: 0.002
  milestones: [600000]
  gamma: 0.5

optimizer:
  optimG:
    name: Adam
    # betas rescaled by the lazy-regularization ratio
    # gen_iters/(gen_iters+1); presumably 0.99^(4/5) — confirm
    beta1: 0.0
    beta2: 0.792
    net_names:
      - gen
  optimD:
    name: Adam
    net_names:
      - disc
    beta1: 0.0
    beta2: 0.9317647058823529

log_config:
  interval: 50
  # NOTE: key is intentionally spelled "visiual_interval" — the Trainer
  # reads cfg.log_config.visiual_interval; do not "fix" only one side.
  visiual_interval: 500

snapshot_config:
  interval: 5000
......@@ -59,19 +59,21 @@ class Transforms():
data = tuple(data)
for transform in self.transforms:
data = transform(data)
if hasattr(transform, 'params') and isinstance(
transform.params, dict):
datas.update(transform.params)
if len(self.input_keys) > 1:
for i, k in enumerate(self.input_keys):
datas[k] = data[i]
else:
datas[k] = data
if self.output_keys is not None:
for i, k in enumerate(self.output_keys):
datas[k] = data[i]
return datas
for i, k in enumerate(self.input_keys):
datas[k] = data[i]
return datas
......
......@@ -27,7 +27,7 @@ class SingleDataset(BaseDataset):
dataroot (str): Directory of dataset.
preprocess (list[dict]): A sequence of data preprocess config.
"""
super(SingleDataset).__init__(self, preprocess)
super(SingleDataset, self).__init__(preprocess)
self.dataroot = dataroot
self.data_infos = self.prepare_data_infos()
......
......@@ -123,6 +123,8 @@ class Trainer:
self.batch_id = 0
self.global_steps = 0
self.weight_interval = cfg.snapshot_config.interval
if self.by_epoch:
self.weight_interval *= self.iters_per_epoch
self.log_interval = cfg.log_config.interval
self.visual_interval = cfg.log_config.visiual_interval
if self.by_epoch:
......@@ -143,6 +145,17 @@ class Trainer:
for net_name, net in self.model.nets.items():
self.model.nets[net_name] = paddle.DataParallel(net, strategy)
def learning_rate_scheduler_step(self):
    """Advance the model's learning-rate scheduler(s) by one step.

    Supports the two layouts a model may expose:
      * a dict of schedulers (e.g. separate 'gen'/'disc' schedulers
        for StyleGAN2) — every scheduler is stepped, or
      * a single ``paddle.optimizer.lr.LRScheduler`` instance.

    Raises:
        ValueError: if ``self.model.lr_scheduler`` is neither a dict
            nor an ``LRScheduler``.
    """
    scheduler = self.model.lr_scheduler
    if isinstance(scheduler, dict):
        for lr_scheduler in scheduler.values():
            lr_scheduler.step()
    elif isinstance(scheduler, paddle.optimizer.lr.LRScheduler):
        scheduler.step()
    else:
        # fixed typo in the original message: "schedulter" -> "scheduler"
        raise ValueError(
            'lr scheduler must be a dict or an instance of LRScheduler')
def train(self):
reader_cost_averager = TimeAverager()
batch_cost_averager = TimeAverager()
......@@ -179,7 +192,7 @@ class Trainer:
if self.current_iter % self.visual_interval == 0:
self.visual('visual_train')
self.model.lr_scheduler.step()
self.learning_rate_scheduler_step()
if self.validate_interval > -1 and self.current_iter % self.validate_interval == 0:
self.test()
......
......@@ -22,5 +22,6 @@ from .esrgan_model import ESRGAN
from .ugatit_model import UGATITModel
from .dc_gan_model import DCGANModel
from .animeganv2_model import AnimeGANV2Model, AnimeGANV2PreTrainModel
from .styleganv2_model import StyleGAN2Model
from .wav2lip_model import Wav2LipModel
from .wav2lip_hq_model import Wav2LipModelHq
......@@ -35,22 +35,22 @@ class ConvLayer(nn.Sequential):
activate=True,
):
layers = []
if downsample:
factor = 2
p = (len(blur_kernel) - factor) + (kernel_size - 1)
pad0 = (p + 1) // 2
pad1 = p // 2
layers.append(Upfirdn2dBlur(blur_kernel, pad=(pad0, pad1)))
stride = 2
self.padding = 0
else:
stride = 1
self.padding = kernel_size // 2
layers.append(
EqualConv2D(
in_channel,
......@@ -59,41 +59,58 @@ class ConvLayer(nn.Sequential):
padding=self.padding,
stride=stride,
bias=bias and not activate,
)
)
))
if activate:
layers.append(FusedLeakyReLU(out_channel, bias=bias))
super().__init__(*layers)
class ResBlock(nn.Layer):
    """Residual downsampling block used by the StyleGANv2 discriminator.

    Two 3x3 conv layers (the second downsamples by 2) plus a 1x1
    downsampling shortcut; the sum is scaled by 1/sqrt(2) to keep the
    activation variance roughly constant.

    Args:
        in_channel (int): number of input feature channels.
        out_channel (int): number of output feature channels.
        blur_kernel (list[int]): FIR kernel used for anti-aliased downsampling.
    """
    def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
        super().__init__()

        self.conv1 = ConvLayer(in_channel, in_channel, 3)
        self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)

        # 1x1 projection shortcut, no activation/bias (diff residue that
        # assigned self.skip twice has been removed)
        self.skip = ConvLayer(in_channel,
                              out_channel,
                              1,
                              downsample=True,
                              activate=False,
                              bias=False)

    def forward(self, input):
        out = self.conv1(input)
        out = self.conv2(out)

        skip = self.skip(input)
        # 1/sqrt(2) keeps the variance of the sum comparable to its inputs
        out = (out + skip) / math.sqrt(2)

        return out
# Temporary workaround: compute variance without pow(), whose
# double-gradient is problematic in Paddle at this version.
def var(x, axis=None, unbiased=True, keepdim=False, name=None):
    """Variance of ``x`` along ``axis`` (sample variance when ``unbiased``).

    Mirrors ``paddle.var`` but squares via ``t * t`` instead of ``t ** 2``
    so that second-order gradients (needed for R1 regularization) work.
    """
    mean_val = paddle.mean(x, axis, True, name)
    residual = x - mean_val
    sq_sum = paddle.sum(residual * residual, axis, keepdim=keepdim, name=name)

    # number of elements reduced per output entry
    denom = paddle.cast(paddle.numel(x), x.dtype) \
        / paddle.cast(paddle.numel(sq_sum), x.dtype)
    if unbiased:
        # Bessel's correction, clamped so we never divide by zero
        one_const = paddle.ones([1], x.dtype)
        denom = paddle.where(denom > one_const, denom - 1., one_const)

    return sq_sum / denom
@DISCRIMINATORS.register()
class StyleGANv2Discriminator(nn.Layer):
def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
super().__init__()
channels = {
4: 512,
8: 512,
......@@ -105,47 +122,48 @@ class StyleGANv2Discriminator(nn.Layer):
512: 32 * channel_multiplier,
1024: 16 * channel_multiplier,
}
convs = [ConvLayer(3, channels[size], 1)]
log_size = int(math.log(size, 2))
in_channel = channels[size]
for i in range(log_size, 2, -1):
out_channel = channels[2 ** (i - 1)]
out_channel = channels[2**(i - 1)]
convs.append(ResBlock(in_channel, out_channel, blur_kernel))
in_channel = out_channel
self.convs = nn.Sequential(*convs)
self.stddev_group = 4
self.stddev_feat = 1
self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
self.final_linear = nn.Sequential(
EqualLinear(channels[4] * 4 * 4, channels[4], activation="fused_lrelu"),
EqualLinear(channels[4] * 4 * 4,
channels[4],
activation="fused_lrelu"),
EqualLinear(channels[4], 1),
)
def forward(self, input):
    """Discriminator forward pass.

    Args:
        input (Tensor): input images — presumably (N, 3, size, size);
            confirm against the config.

    Returns:
        Tensor: per-sample realness logits of shape (N, 1).
    """
    out = self.convs(input)

    # Minibatch standard deviation: append one channel holding the
    # group-wise feature stddev, which discourages mode collapse.
    batch, channel, height, width = out.shape
    group = min(batch, self.stddev_group)
    stddev = out.reshape((group, -1, self.stddev_feat,
                          channel // self.stddev_feat, height, width))
    # module-level var() is used instead of Tensor.var() to avoid the
    # pow() double-grad problem (diff residue computing the old
    # Tensor.var() path has been removed — it was dead work)
    stddev = paddle.sqrt(var(stddev, 0, unbiased=False) + 1e-8)
    stddev = stddev.mean([2, 3, 4], keepdim=True).squeeze(2)
    stddev = stddev.tile((group, 1, height, width))
    out = paddle.concat([out, stddev], 1)

    out = self.final_conv(out)
    out = out.reshape((batch, -1))
    out = self.final_linear(out)

    return out
......@@ -27,11 +27,12 @@ from ...modules.upfirdn2d import Upfirdn2dUpsample, Upfirdn2dBlur
class PixelNorm(nn.Layer):
def __init__(self):
super().__init__()
def forward(self, input):
return input * paddle.rsqrt(paddle.mean(input ** 2, 1, keepdim=True) + 1e-8)
return input * paddle.rsqrt(
paddle.mean(input * input, 1, keepdim=True) + 1e-8)
class ModulatedConv2D(nn.Layer):
def __init__(
self,
......@@ -45,75 +46,78 @@ class ModulatedConv2D(nn.Layer):
blur_kernel=[1, 3, 3, 1],
):
super().__init__()
self.eps = 1e-8
self.kernel_size = kernel_size
self.in_channel = in_channel
self.out_channel = out_channel
self.upsample = upsample
self.downsample = downsample
if upsample:
factor = 2
p = (len(blur_kernel) - factor) - (kernel_size - 1)
pad0 = (p + 1) // 2 + factor - 1
pad1 = p // 2 + 1
self.blur = Upfirdn2dBlur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)
self.blur = Upfirdn2dBlur(blur_kernel,
pad=(pad0, pad1),
upsample_factor=factor)
if downsample:
factor = 2
p = (len(blur_kernel) - factor) + (kernel_size - 1)
pad0 = (p + 1) // 2
pad1 = p // 2
self.blur = Upfirdn2dBlur(blur_kernel, pad=(pad0, pad1))
fan_in = in_channel * kernel_size ** 2
fan_in = in_channel * (kernel_size * kernel_size)
self.scale = 1 / math.sqrt(fan_in)
self.padding = kernel_size // 2
self.weight = self.create_parameter(
(1, out_channel, in_channel, kernel_size, kernel_size), default_initializer=nn.initializer.Normal()
)
(1, out_channel, in_channel, kernel_size, kernel_size),
default_initializer=nn.initializer.Normal())
self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
self.demodulate = demodulate
def __repr__(self):
    # Readable summary, e.g. "ModulatedConv2D(512, 512, 3, upsample=False,
    # downsample=False)". Diff residue (a duplicated closing paren line)
    # has been removed.
    return (
        f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, "
        f"upsample={self.upsample}, downsample={self.downsample})")
def forward(self, input, style):
batch, in_channel, height, width = input.shape
style = self.modulation(style).reshape((batch, 1, in_channel, 1, 1))
weight = self.scale * self.weight * style
if self.demodulate:
demod = paddle.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
demod = paddle.rsqrt((weight * weight).sum([2, 3, 4]) + 1e-8)
weight = weight * demod.reshape((batch, self.out_channel, 1, 1, 1))
weight = weight.reshape((
batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
))
weight = weight.reshape((batch * self.out_channel, in_channel,
self.kernel_size, self.kernel_size))
if self.upsample:
input = input.reshape((1, batch * in_channel, height, width))
weight = weight.reshape((
batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
))
weight = weight.transpose((0, 2, 1, 3, 4)).reshape((
batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
))
out = F.conv2d_transpose(input, weight, padding=0, stride=2, groups=batch)
weight = weight.reshape((batch, self.out_channel, in_channel,
self.kernel_size, self.kernel_size))
weight = weight.transpose((0, 2, 1, 3, 4)).reshape(
(batch * in_channel, self.out_channel, self.kernel_size,
self.kernel_size))
out = F.conv2d_transpose(input,
weight,
padding=0,
stride=2,
groups=batch)
_, _, height, width = out.shape
out = out.reshape((batch, self.out_channel, height, width))
out = self.blur(out)
elif self.downsample:
input = self.blur(input)
_, _, height, width = input.shape
......@@ -121,43 +125,46 @@ class ModulatedConv2D(nn.Layer):
out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
_, _, height, width = out.shape
out = out.reshape((batch, self.out_channel, height, width))
else:
input = input.reshape((1, batch * in_channel, height, width))
out = F.conv2d(input, weight, padding=self.padding, groups=batch)
_, _, height, width = out.shape
out = out.reshape((batch, self.out_channel, height, width))
return out
class NoiseInjection(nn.Layer):
    """Adds per-pixel Gaussian noise scaled by one learned scalar weight."""
    def __init__(self):
        super().__init__()

        # single scalar strength, initialized to 0 so noise starts disabled
        # (duplicated pre-refactor create_parameter line removed)
        self.weight = self.create_parameter(
            (1, ), default_initializer=nn.initializer.Constant(0.0))

    def forward(self, image, noise=None):
        if noise is None:
            # fresh noise per call, one channel broadcast across all features
            batch, _, height, width = image.shape
            noise = paddle.randn((batch, 1, height, width))

        return image + self.weight * noise
class ConstantInput(nn.Layer):
    """Learned constant 4x4 tensor that seeds the StyleGAN2 synthesis network."""
    def __init__(self, channel, size=4):
        super().__init__()

        # (duplicated pre-refactor create_parameter line removed)
        self.input = self.create_parameter(
            (1, channel, size, size),
            default_initializer=nn.initializer.Normal())

    def forward(self, input):
        # `input` is consulted only for its batch size
        batch = input.shape[0]
        out = self.input.tile((batch, 1, 1, 1))

        return out
class StyledConv(nn.Layer):
def __init__(
self,
......@@ -170,7 +177,7 @@ class StyledConv(nn.Layer):
demodulate=True,
):
super().__init__()
self.conv = ModulatedConv2D(
in_channel,
out_channel,
......@@ -180,40 +187,49 @@ class StyledConv(nn.Layer):
blur_kernel=blur_kernel,
demodulate=demodulate,
)
self.noise = NoiseInjection()
self.activate = FusedLeakyReLU(out_channel)
def forward(self, input, style, noise=None):
    """Styled convolution: modulated conv -> noise injection -> activation."""
    modulated = self.conv(input, style)
    noisy = self.noise(modulated, noise=noise)
    return self.activate(noisy)
class ToRGB(nn.Layer):
    """Projects a feature map to RGB and optionally merges an upsampled skip.

    Args:
        in_channel (int): channels of the incoming feature map.
        style_dim (int): dimensionality of the style vector.
        upsample (bool): whether to upsample the skip connection by 2x.
        blur_kernel (list[int]): FIR kernel for the skip upsampler.
    """
    def __init__(self,
                 in_channel,
                 style_dim,
                 upsample=True,
                 blur_kernel=[1, 3, 3, 1]):
        super().__init__()

        if upsample:
            self.upsample = Upfirdn2dUpsample(blur_kernel)

        # demodulate=False: plain style-modulated 1x1 conv to 3 channels
        # (diff residue duplicating the __init__ signature and the
        # conv/bias assignments has been removed)
        self.conv = ModulatedConv2D(in_channel,
                                    3,
                                    1,
                                    style_dim,
                                    demodulate=False)
        self.bias = self.create_parameter((1, 3, 1, 1),
                                          nn.initializer.Constant(0.0))

    def forward(self, input, style, skip=None):
        out = self.conv(input, style)
        out = out + self.bias

        if skip is not None:
            skip = self.upsample(skip)
            out = out + skip

        return out
@GENERATORS.register()
class StyleGANv2Generator(nn.Layer):
def __init__(
......@@ -226,22 +242,22 @@ class StyleGANv2Generator(nn.Layer):
lr_mlp=0.01,
):
super().__init__()
self.size = size
self.style_dim = style_dim
layers = [PixelNorm()]
for i in range(n_mlp):
layers.append(
EqualLinear(
style_dim, style_dim, lr_mul=lr_mlp, activation="fused_lrelu"
)
)
EqualLinear(style_dim,
style_dim,
lr_mul=lr_mlp,
activation="fused_lrelu"))
self.style = nn.Sequential(*layers)
self.channels = {
4: 512,
8: 512,
......@@ -253,31 +269,34 @@ class StyleGANv2Generator(nn.Layer):
512: 32 * channel_multiplier,
1024: 16 * channel_multiplier,
}
self.input = ConstantInput(self.channels[4])
self.conv1 = StyledConv(
self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
)
self.conv1 = StyledConv(self.channels[4],
self.channels[4],
3,
style_dim,
blur_kernel=blur_kernel)
self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)
self.log_size = int(math.log(size, 2))
self.num_layers = (self.log_size - 2) * 2 + 1
self.convs = nn.LayerList()
self.upsamples = nn.LayerList()
self.to_rgbs = nn.LayerList()
self.noises = nn.Layer()
in_channel = self.channels[4]
for layer_idx in range(self.num_layers):
res = (layer_idx + 5) // 2
shape = [1, 1, 2 ** res, 2 ** res]
self.noises.register_buffer(f"noise_{layer_idx}", paddle.randn(shape))
shape = [1, 1, 2**res, 2**res]
self.noises.register_buffer(f"noise_{layer_idx}",
paddle.randn(shape))
for i in range(3, self.log_size + 1):
out_channel = self.channels[2 ** i]
out_channel = self.channels[2**i]
self.convs.append(
StyledConv(
in_channel,
......@@ -286,41 +305,39 @@ class StyleGANv2Generator(nn.Layer):
style_dim,
upsample=True,
blur_kernel=blur_kernel,
)
)
))
self.convs.append(
StyledConv(
out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
)
)
StyledConv(out_channel,
out_channel,
3,
style_dim,
blur_kernel=blur_kernel))
self.to_rgbs.append(ToRGB(out_channel, style_dim))
in_channel = out_channel
self.n_latent = self.log_size * 2 - 2
def make_noise(self):
    """Build a fixed set of per-layer noise maps.

    Returns:
        list[Tensor]: one 4x4 map, then two maps per resolution from
        8x8 up to the output size. (Diff residue duplicating the
        ``2 ** i`` lines has been removed.)
    """
    noises = [paddle.randn((1, 1, 2**2, 2**2))]

    for i in range(3, self.log_size + 1):
        for _ in range(2):
            noises.append(paddle.randn((1, 1, 2**i, 2**i)))

    return noises
def mean_latent(self, n_latent):
    """Estimate the mean w latent from ``n_latent`` random z samples.

    Used as the anchor for truncation. (Diff residue duplicating the
    ``latent_in`` construction has been removed.)
    """
    latent_in = paddle.randn((n_latent, self.style_dim))
    latent = self.style(latent_in).mean(0, keepdim=True)

    return latent
def get_latent(self, input):
    """Map a latent z through the style MLP into w space."""
    return self.style(input)
def forward(
self,
styles,
......@@ -334,62 +351,65 @@ class StyleGANv2Generator(nn.Layer):
):
if not input_is_latent:
styles = [self.style(s) for s in styles]
if noise is None:
if randomize_noise:
noise = [None] * self.num_layers
else:
noise = [
getattr(self.noises, f"noise_{i}") for i in range(self.num_layers)
getattr(self.noises, f"noise_{i}")
for i in range(self.num_layers)
]
if truncation < 1:
style_t = []
for style in styles:
style_t.append(
truncation_latent + truncation * (style - truncation_latent)
)
style_t.append(truncation_latent + truncation *
(style - truncation_latent))
styles = style_t
if len(styles) < 2:
inject_index = self.n_latent
if styles[0].ndim < 3:
latent = styles[0].unsqueeze(1).tile((1, inject_index, 1))
else:
latent = styles[0]
else:
if inject_index is None:
inject_index = random.randint(1, self.n_latent - 1)
latent = styles[0].unsqueeze(1).tile((1, inject_index, 1))
latent2 = styles[1].unsqueeze(1).tile((1, self.n_latent - inject_index, 1))
latent2 = styles[1].unsqueeze(1).tile(
(1, self.n_latent - inject_index, 1))
latent = paddle.concat([latent, latent2], 1)
out = self.input(latent)
out = self.conv1(out, latent[:, 0], noise=noise[0])
skip = self.to_rgb1(out, latent[:, 1])
i = 1
for conv1, conv2, noise1, noise2, to_rgb in zip(
self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
):
for conv1, conv2, noise1, noise2, to_rgb in zip(self.convs[::2],
self.convs[1::2],
noise[1::2],
noise[2::2],
self.to_rgbs):
out = conv1(out, latent[:, i], noise=noise1)
out = conv2(out, latent[:, i + 1], noise=noise2)
skip = to_rgb(out, latent[:, i + 2], skip)
i += 2
image = skip
if return_latents:
return image, latent
else:
return image, None
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import random
import paddle
import paddle.nn as nn
from .base_model import BaseModel
from .builder import MODELS
from .criterions import build_criterion
from .generators.builder import build_generator
from .discriminators.builder import build_discriminator
from ..solver import build_lr_scheduler, build_optimizer
def r1_penalty(real_pred, real_img):
    """R1 gradient penalty for the discriminator.

    Penalizing the discriminator's gradient on real data alone keeps it
    from building a non-zero gradient orthogonal to the data manifold
    once the generator matches the true distribution, stabilizing the
    GAN game.

    Ref:
        Eq. 9 in "Which Training Methods for GANs do actually Converge?".

    Args:
        real_pred (Tensor): discriminator logits on real images.
        real_img (Tensor): the real images (must allow gradients).

    Returns:
        Tensor: scalar penalty, mean squared gradient norm.
    """
    gradients = paddle.grad(outputs=real_pred.sum(),
                            inputs=real_img,
                            create_graph=True)[0]
    # g * g instead of g.pow(2): pow() double-grad workaround
    squared = gradients * gradients
    flat = squared.reshape([gradients.shape[0], -1])
    return flat.sum(1).mean()
def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
    """Path-length regularizer for the StyleGAN2 generator.

    Encourages a fixed-magnitude change in the image for a fixed-magnitude
    step in w, tracked against an exponential moving average of path lengths.

    Args:
        fake_img (Tensor): generated images, (N, C, H, W).
        latents (Tensor): the w latents that produced ``fake_img``.
        mean_path_length (Tensor | float): running mean of path lengths.
        decay (float): EMA decay for the running mean.

    Returns:
        tuple: (penalty, detached mean path length, detached updated EMA).
    """
    height, width = fake_img.shape[2], fake_img.shape[3]
    # image-space noise scaled so the expected dot product is resolution-free
    noise = paddle.randn(fake_img.shape) / math.sqrt(height * width)

    grad = paddle.grad(outputs=(fake_img * noise).sum(),
                       inputs=latents,
                       create_graph=True)[0]
    # g * g instead of g.pow(2): pow() double-grad workaround
    path_lengths = paddle.sqrt((grad * grad).sum(2).mean(1))

    path_mean = mean_path_length + decay * (path_lengths.mean() -
                                            mean_path_length)
    deviation = path_lengths - path_mean
    path_penalty = (deviation * deviation).mean()

    return path_penalty, path_lengths.detach().mean(), path_mean.detach()
@MODELS.register()
class StyleGAN2Model(BaseModel):
    """StyleGAN2 training model (generator + discriminator + lazy regularization).

    StyleGAN2 paper: https://arxiv.org/pdf/1912.04958.pdf
    """
    def __init__(self,
                 generator,
                 discriminator=None,
                 gan_criterion=None,
                 num_style_feat=512,
                 mixing_prob=0.9,
                 r1_reg_weight=10.,
                 path_reg_weight=2.,
                 path_batch_shrink=2.,
                 params=None):
        """Initialize the StyleGAN2 model class.

        Args:
            generator (dict): config of generator.
            discriminator (dict): config of discriminator.
            gan_criterion (dict): config of gan criterion.
            num_style_feat (int): dimensionality of the style/latent vectors.
            mixing_prob (float): probability of style mixing (two z samples).
            r1_reg_weight (float): weight of the R1 penalty on the discriminator.
            path_reg_weight (float): weight of the path-length penalty.
            path_batch_shrink (float): divisor applied to the batch size for
                the path-length regularization pass.
            params (dict|None): optional overrides (gen_iters, disc_iters, ...).
        """
        super(StyleGAN2Model, self).__init__(params)
        # lazy-regularization intervals: regularizers run every N-th step
        self.gen_iters = 4 if self.params is None else self.params.get(
            'gen_iters', 4)
        self.disc_iters = 16 if self.params is None else self.params.get(
            'disc_iters', 16)
        self.disc_start_iters = (0 if self.params is None else self.params.get(
            'disc_start_iters', 0))
        self.visual_iters = (500 if self.params is None else self.params.get(
            'visual_iters', 500))

        self.mixing_prob = mixing_prob
        self.num_style_feat = num_style_feat
        self.r1_reg_weight = r1_reg_weight
        self.path_reg_weight = path_reg_weight
        self.path_batch_shrink = path_batch_shrink
        # running EMA of generator path lengths, updated by g_path_regularize
        self.mean_path_length = 0

        self.nets['gen'] = build_generator(generator)

        # define discriminators
        if discriminator:
            self.nets['disc'] = build_discriminator(discriminator)
            # EMA copy of the generator used for evaluation/visualization;
            # model_ema(0) copies gen's weights into gen_ema exactly
            self.nets['gen_ema'] = build_generator(generator)
            self.model_ema(0)

            self.nets['gen'].train()
            self.nets['gen_ema'].eval()
            self.nets['disc'].train()
        self.current_iter = 1

        # define loss functions
        if gan_criterion:
            self.gan_criterion = build_criterion(gan_criterion)

    def setup_lr_schedulers(self, cfg):
        """Build separate gen/disc schedulers with lazy-regularization ratios.

        The base learning rate is scaled by N/(N+1) because regularized
        steps only happen every N-th iteration (StyleGAN2 lazy reg trick).
        """
        self.lr_scheduler = dict()
        gen_cfg = cfg.copy()
        net_g_reg_ratio = self.gen_iters / (self.gen_iters + 1)
        gen_cfg['learning_rate'] = cfg['learning_rate'] * net_g_reg_ratio
        self.lr_scheduler['gen'] = build_lr_scheduler(gen_cfg)

        disc_cfg = cfg.copy()
        net_d_reg_ratio = self.disc_iters / (self.disc_iters + 1)
        disc_cfg['learning_rate'] = cfg['learning_rate'] * net_d_reg_ratio
        self.lr_scheduler['disc'] = build_lr_scheduler(disc_cfg)
        return self.lr_scheduler

    def setup_optimizers(self, lr, cfg):
        """Create optimizers from config; 'optimG' drives gen, 'optimD' drives disc."""
        for opt_name, opt_cfg in cfg.items():
            if opt_name == 'optimG':
                _lr = lr['gen']
            elif opt_name == 'optimD':
                _lr = lr['disc']
            else:
                # NOTE(review): message has a missing quote around optimD
                raise ValueError("opt name must be in ['optimG', optimD]")
            cfg_ = opt_cfg.copy()
            net_names = cfg_.pop('net_names')
            parameters = []
            for net_name in net_names:
                parameters += self.nets[net_name].parameters()
            self.optimizers[opt_name] = build_optimizer(cfg_, _lr, parameters)
        return self.optimizers

    def get_bare_model(self, net):
        """Get bare model, especially under wrapping with DataParallel.
        """
        if isinstance(net, (paddle.DataParallel)):
            net = net._layers
        return net

    def model_ema(self, decay=0.999):
        """Blend gen's weights into gen_ema: ema = ema*decay + gen*(1-decay).

        decay=0 performs a hard copy (used at construction time).
        """
        net_g = self.get_bare_model(self.nets['gen'])

        net_g_params = dict(net_g.named_parameters())

        # NOTE(review): local is named "neg_g_ema" — presumably "net_g_ema"
        neg_g_ema = self.get_bare_model(self.nets['gen_ema'])
        net_g_ema_params = dict(neg_g_ema.named_parameters())

        for k in net_g_ema_params.keys():
            net_g_ema_params[k].set_value(net_g_ema_params[k] * (decay) +
                                          (net_g_params[k] * (1 - decay)))

    def setup_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        Args:
            input (dict): include the data itself and its metadata information.
        """
        # NOTE(review): paddle.fluid.dygraph.to_variable is a legacy API;
        # paddle.to_tensor is the modern equivalent — confirm before changing
        self.real_img = paddle.fluid.dygraph.to_variable(input['A'])

    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
        pass

    def make_noise(self, batch, num_noise):
        """Sample z latents: a single tensor when num_noise == 1, else a list."""
        if num_noise == 1:
            noises = paddle.randn([batch, self.num_style_feat])
        else:
            noises = []
            for _ in range(num_noise):
                noises.append(paddle.randn([batch, self.num_style_feat]))
        return noises

    def mixing_noise(self, batch, prob):
        """With probability ``prob`` return two latents (style mixing), else one."""
        if random.random() < prob:
            return self.make_noise(batch, 2)
        else:
            return [self.make_noise(batch, 1)]

    def train_iter(self, optimizers=None):
        """One full training iteration: D step (+R1 every disc_iters),
        G step (+path-length reg every gen_iters), then generator EMA."""
        current_iter = self.current_iter
        self.set_requires_grad(self.nets['disc'], True)
        optimizers['optimD'].clear_grad()

        batch = self.real_img.shape[0]
        noise = self.mixing_noise(batch, self.mixing_prob)
        fake_img, _ = self.nets['gen'](noise)
        self.visual_items['real_img'] = self.real_img
        self.visual_items['fake_img'] = fake_img
        # detach fakes: discriminator step must not backprop into gen
        fake_pred = self.nets['disc'](fake_img.detach())
        real_pred = self.nets['disc'](self.real_img)
        # wgan loss with softplus (logistic loss) for discriminator
        l_d_total = 0.
        l_d = self.gan_criterion(real_pred, True,
                                 is_disc=True) + self.gan_criterion(
                                     fake_pred, False, is_disc=True)
        self.losses['l_d'] = l_d
        # In wgan, real_score should be positive and fake_score should be
        # negative
        self.losses['real_score'] = real_pred.detach().mean()
        self.losses['fake_score'] = fake_pred.detach().mean()
        l_d_total += l_d

        # lazy R1 regularization: only every disc_iters-th step, scaled up
        # by disc_iters to compensate; "+ 0 * real_pred[0]" keeps the graph
        # connected so backward() succeeds
        if current_iter % self.disc_iters == 0:
            self.real_img.stop_gradient = False
            real_pred = self.nets['disc'](self.real_img)
            l_d_r1 = r1_penalty(real_pred, self.real_img)
            l_d_r1 = (self.r1_reg_weight / 2 * l_d_r1 * self.disc_iters +
                      0 * real_pred[0])
            self.losses['l_d_r1'] = l_d_r1.detach().mean()
            l_d_total += l_d_r1
        l_d_total.backward()
        optimizers['optimD'].step()

        self.set_requires_grad(self.nets['disc'], False)
        optimizers['optimG'].clear_grad()
        noise = self.mixing_noise(batch, self.mixing_prob)
        fake_img, _ = self.nets['gen'](noise)
        fake_pred = self.nets['disc'](fake_img)
        # wgan loss with softplus (non-saturating loss) for generator
        l_g_total = 0.
        l_g = self.gan_criterion(fake_pred, True, is_disc=False)
        self.losses['l_g'] = l_g
        l_g_total += l_g

        # lazy path-length regularization, on a shrunken batch for speed
        if current_iter % self.gen_iters == 0:
            # NOTE(review): path_batch_shrink defaults to a float (2.), so
            # batch // self.path_batch_shrink is a float — confirm intended
            path_batch_size = max(1, batch // self.path_batch_shrink)
            noise = self.mixing_noise(path_batch_size, self.mixing_prob)
            fake_img, latents = self.nets['gen'](noise, return_latents=True)
            l_g_path, path_lengths, self.mean_path_length = g_path_regularize(
                fake_img, latents, self.mean_path_length)

            l_g_path = (self.path_reg_weight * self.gen_iters * l_g_path +
                        0 * fake_img[0, 0, 0, 0])
            l_g_total += l_g_path
            self.losses['l_g_path'] = l_g_path.detach().mean()
            self.losses['path_length'] = path_lengths
        l_g_total.backward()
        optimizers['optimG'].step()

        # EMA
        self.model_ema(decay=0.5**(32 / (10 * 1000)))

        # NOTE(review): this is truthy for every iter NOT divisible by
        # visual_iters — almost certainly intended "== 0"; confirm upstream
        if self.current_iter % self.visual_iters:
            sample_z = [self.make_noise(1, 1)]
            sample, _ = self.nets['gen_ema'](sample_z)
            self.visual_items['fake_img_ema'] = sample

        self.current_iter += 1
......@@ -24,25 +24,30 @@ class EqualConv2D(nn.Layer):
"""This convolutional layer class stabilizes the learning rate changes of its parameters.
Equalizing learning rate keeps the weights in the network at a similar scale during training.
"""
def __init__(
self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
):
def __init__(self,
in_channel,
out_channel,
kernel_size,
stride=1,
padding=0,
bias=True):
super().__init__()
self.weight = self.create_parameter(
(out_channel, in_channel, kernel_size, kernel_size), default_initializer=nn.initializer.Normal()
)
self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
(out_channel, in_channel, kernel_size, kernel_size),
default_initializer=nn.initializer.Normal())
self.scale = 1 / math.sqrt(in_channel * (kernel_size * kernel_size))
self.stride = stride
self.padding = padding
if bias:
self.bias = self.create_parameter((out_channel,), nn.initializer.Constant(0.0))
self.bias = self.create_parameter((out_channel, ),
nn.initializer.Constant(0.0))
else:
self.bias = None
def forward(self, input):
out = F.conv2d(
input,
......@@ -51,51 +56,57 @@ class EqualConv2D(nn.Layer):
stride=self.stride,
padding=self.padding,
)
return out
def __repr__(self):
    # e.g. "EqualConv2D(512, 512, 3, stride=1, padding=1)"; note the
    # weight shape is (out_channel, in_channel, k, k), so shape[1] is the
    # in-channel count and shape[0] the out-channel count.
    return (
        f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},"
        f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})"
    )
class EqualLinear(nn.Layer):
"""This linear layer class stabilizes the learning rate changes of its parameters.
Equalizing learning rate keeps the weights in the network at a similar scale during training.
"""
def __init__(
self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
):
def __init__(self,
in_dim,
out_dim,
bias=True,
bias_init=0,
lr_mul=1,
activation=None):
super().__init__()
self.weight = self.create_parameter((in_dim, out_dim), default_initializer=nn.initializer.Normal())
self.weight[:] = (self.weight / lr_mul).detach()
self.weight = self.create_parameter(
(in_dim, out_dim), default_initializer=nn.initializer.Normal())
self.weight.set_value((self.weight / lr_mul))
if bias:
self.bias = self.create_parameter((out_dim,), nn.initializer.Constant(bias_init))
self.bias = self.create_parameter(
(out_dim, ), nn.initializer.Constant(bias_init))
else:
self.bias = None
self.activation = activation
self.scale = (1 / math.sqrt(in_dim)) * lr_mul
self.lr_mul = lr_mul
def forward(self, input):
    """Equalized linear transform.

    With ``activation`` set, the bias is applied inside fused_leaky_relu;
    otherwise it is passed to F.linear directly. (Diff residue that
    duplicated the no-activation branch has been removed.)

    NOTE(review): when constructed with bias=False, ``self.bias`` is None
    and ``self.bias * self.lr_mul`` raises — same as the original; confirm
    callers always use bias=True.
    """
    if self.activation:
        out = F.linear(input, self.weight * self.scale)
        out = fused_leaky_relu(out, self.bias * self.lr_mul)
    else:
        out = F.linear(input,
                       self.weight * self.scale,
                       bias=self.bias * self.lr_mul)

    return out
def __repr__(self):
return (
f"{self.__class__.__name__}({self.weight.shape[0]}, {self.weight.shape[1]})"
......
......@@ -15,37 +15,35 @@
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
def upfirdn2d_native(
input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
):
def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1,
pad_y0, pad_y1):
_, channel, in_h, in_w = input.shape
input = input.reshape((-1, in_h, in_w, 1))
_, in_h, in_w, minor = input.shape
kernel_h, kernel_w = kernel.shape
out = input.reshape((-1, in_h, 1, in_w, 1, minor))
out = out.transpose((0,1,3,5,2,4))
out = out.reshape((-1,1,1,1))
out = out.transpose((0, 1, 3, 5, 2, 4))
out = out.reshape((-1, 1, 1, 1))
out = F.pad(out, [0, up_x - 1, 0, up_y - 1])
out = out.reshape((-1, in_h, in_w, minor, up_y, up_x))
out = out.transpose((0,3,1,4,2,5))
out = out.transpose((0, 3, 1, 4, 2, 5))
out = out.reshape((-1, minor, in_h * up_y, in_w * up_x))
out = F.pad(
out, [max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
)
out = out[
:,:,
max(-pad_y0, 0) : out.shape[2] - max(-pad_y1, 0),
max(-pad_x0, 0) : out.shape[3] - max(-pad_x1, 0),
]
out = out.reshape((
[-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
))
out, [max(pad_x0, 0),
max(pad_x1, 0),
max(pad_y0, 0),
max(pad_y1, 0)])
out = out[:, :,
max(-pad_y0, 0):out.shape[2] - max(-pad_y1, 0),
max(-pad_x0, 0):out.shape[3] - max(-pad_x1, 0), ]
out = out.reshape(
([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]))
w = paddle.flip(kernel, [0, 1]).reshape((1, 1, kernel_h, kernel_w))
out = F.conv2d(out, w)
out = out.reshape((
......@@ -56,88 +54,95 @@ def upfirdn2d_native(
))
out = out.transpose((0, 2, 3, 1))
out = out[:, ::down_y, ::down_x, :]
out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
return out.reshape((-1, channel, out_h, out_w))
def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
    """Upsample, apply an FIR ``kernel``, then downsample, in one pass.

    Uses the same up/down factor and padding for both axes. (Diff residue
    that duplicated the native call has been removed.)
    """
    out = upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1],
                           pad[0], pad[1])

    return out
def make_kernel(k):
    """Build a normalized 2-D FIR kernel from a 1-D or 2-D specification."""
    kernel = paddle.to_tensor(k, dtype='float32')

    # a separable 1-D kernel becomes 2-D via its outer product
    if kernel.ndim == 1:
        kernel = kernel.unsqueeze(0) * kernel.unsqueeze(1)

    # normalize so the filter has unit DC gain
    kernel /= kernel.sum()
    return kernel
class Upfirdn2dUpsample(nn.Layer):
    """Upsample by ``factor`` then blur with an FIR kernel (anti-imaging)."""
    def __init__(self, kernel, factor=2):
        super().__init__()

        self.factor = factor
        # scale the kernel so overall gain stays 1 after zero-stuffing
        # (diff residue duplicating this line has been removed)
        kernel = make_kernel(kernel) * (factor * factor)
        self.register_buffer("kernel", kernel)

        p = kernel.shape[0] - factor

        pad0 = (p + 1) // 2 + factor - 1
        pad1 = p // 2

        self.pad = (pad0, pad1)

    def forward(self, input):
        out = upfirdn2d(input,
                        self.kernel,
                        up=self.factor,
                        down=1,
                        pad=self.pad)

        return out
class Upfirdn2dDownsample(nn.Layer):
    """Blur with an FIR kernel then downsample by ``factor`` (anti-aliasing)."""
    def __init__(self, kernel, factor=2):
        super().__init__()

        self.factor = factor
        kernel = make_kernel(kernel)
        self.register_buffer("kernel", kernel)

        p = kernel.shape[0] - factor

        pad0 = (p + 1) // 2
        pad1 = p // 2

        self.pad = (pad0, pad1)

    def forward(self, input):
        # (diff residue duplicating the call has been removed)
        out = upfirdn2d(input,
                        self.kernel,
                        up=1,
                        down=self.factor,
                        pad=self.pad)

        return out
class Upfirdn2dBlur(nn.Layer):
    """FIR blur; gain-compensated when applied after an upsampling conv."""
    def __init__(self, kernel, pad, upsample_factor=1):
        super().__init__()

        kernel = make_kernel(kernel)

        if upsample_factor > 1:
            # compensate for the energy spread of zero-stuffed upsampling
            kernel = kernel * (upsample_factor * upsample_factor)

        # persistable=False: the kernel is derived from constructor args,
        # so it need not be saved in checkpoints (diff residue duplicating
        # the old register_buffer call has been removed)
        self.register_buffer("kernel", kernel, persistable=False)

        self.pad = pad

    def forward(self, input):
        out = upfirdn2d(input, self.kernel, pad=self.pad)

        return out
import librosa
import librosa.filters
import numpy as np
from scipy import signal
from scipy.io import wavfile
from paddle.utils import try_import
from .audio_config import get_audio_config
audio_config = get_audio_config()
def load_wav(path, sr):
    """Load an audio file at sample rate ``sr`` and return the waveform array."""
    # re-imported via try_import — presumably to raise a friendly install
    # hint when librosa is missing; confirm against paddle.utils.try_import
    librosa = try_import('librosa')
    return librosa.core.load(path, sr=sr)[0]
......@@ -19,6 +19,7 @@ def save_wav(wav, path, sr):
def save_wavenet_wav(wav, path, sr):
    """Write waveform ``wav`` to ``path`` at sample rate ``sr`` via librosa."""
    # try_import raises a friendly install hint when librosa is missing
    librosa = try_import('librosa')
    librosa.output.write_wav(path, wav, sr=sr)
......@@ -75,6 +76,7 @@ def _stft(y):
if audio_config.use_lws:
return _lws_processor(audio_config).stft(y).T
else:
librosa = try_import('librosa')
return librosa.stft(y=y,
n_fft=audio_config.n_fft,
hop_length=get_hop_size(),
......@@ -123,6 +125,7 @@ def _linear_to_mel(spectogram):
def _build_mel_basis():
assert audio_config.fmax <= audio_config.sample_rate // 2
librosa = try_import('librosa')
return librosa.filters.mel(audio_config.sample_rate,
audio_config.n_fft,
n_mels=audio_config.num_mels,
......
Markdown is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register or sign in.