未验证 提交 530a6a8c 编写于 作者: L LielinJiang 提交者: GitHub

Add train code of stylegan2 (#149)

* add stylegan model
上级 e13e1c18
total_iters: 800000
output_dir: output_dir
model:
name: StyleGAN2Model
generator:
name: StyleGANv2Generator
size: 256
style_dim: 512
n_mlp: 8
discriminator:
name: StyleGANv2Discriminator
size: 256
gan_criterion:
name: GANLoss
gan_mode: logistic
loss_weight: !!float 1
# r1 regularization for discriminator
r1_reg_weight: 10.
# path length regularization for generator
path_batch_shrink: 2.
path_reg_weight: 2.
params:
gen_iters: 4
disc_iters: 16
dataset:
train:
name: SingleDataset
dataroot: data/ffhq/images256x256/
num_workers: 3
batch_size: 3
preprocess:
- name: LoadImageFromFile
key: A
- name: Transforms
input_keys: [A]
pipeline:
- name: RandomHorizontalFlip
- name: Transpose
- name: Normalize
mean: [127.5, 127.5, 127.5]
std: [127.5, 127.5, 127.5]
lr_scheduler:
name: MultiStepDecay
learning_rate: 0.002
milestones: [600000]
gamma: 0.5
optimizer:
optimG:
name: Adam
beta1: 0.0
beta2: 0.792
net_names:
- gen
optimD:
name: Adam
net_names:
- disc
beta1: 0.0
beta2: 0.9317647058823529
log_config:
interval: 50
visiual_interval: 500
snapshot_config:
interval: 5000
...@@ -59,19 +59,21 @@ class Transforms(): ...@@ -59,19 +59,21 @@ class Transforms():
data = tuple(data) data = tuple(data)
for transform in self.transforms: for transform in self.transforms:
data = transform(data) data = transform(data)
if hasattr(transform, 'params') and isinstance( if hasattr(transform, 'params') and isinstance(
transform.params, dict): transform.params, dict):
datas.update(transform.params) datas.update(transform.params)
if len(self.input_keys) > 1:
for i, k in enumerate(self.input_keys):
datas[k] = data[i]
else:
datas[k] = data
if self.output_keys is not None: if self.output_keys is not None:
for i, k in enumerate(self.output_keys): for i, k in enumerate(self.output_keys):
datas[k] = data[i] datas[k] = data[i]
return datas return datas
for i, k in enumerate(self.input_keys):
datas[k] = data[i]
return datas return datas
......
...@@ -27,7 +27,7 @@ class SingleDataset(BaseDataset): ...@@ -27,7 +27,7 @@ class SingleDataset(BaseDataset):
dataroot (str): Directory of dataset. dataroot (str): Directory of dataset.
preprocess (list[dict]): A sequence of data preprocess config. preprocess (list[dict]): A sequence of data preprocess config.
""" """
super(SingleDataset).__init__(self, preprocess) super(SingleDataset, self).__init__(preprocess)
self.dataroot = dataroot self.dataroot = dataroot
self.data_infos = self.prepare_data_infos() self.data_infos = self.prepare_data_infos()
......
...@@ -123,6 +123,8 @@ class Trainer: ...@@ -123,6 +123,8 @@ class Trainer:
self.batch_id = 0 self.batch_id = 0
self.global_steps = 0 self.global_steps = 0
self.weight_interval = cfg.snapshot_config.interval self.weight_interval = cfg.snapshot_config.interval
if self.by_epoch:
self.weight_interval *= self.iters_per_epoch
self.log_interval = cfg.log_config.interval self.log_interval = cfg.log_config.interval
self.visual_interval = cfg.log_config.visiual_interval self.visual_interval = cfg.log_config.visiual_interval
if self.by_epoch: if self.by_epoch:
...@@ -143,6 +145,17 @@ class Trainer: ...@@ -143,6 +145,17 @@ class Trainer:
for net_name, net in self.model.nets.items(): for net_name, net in self.model.nets.items():
self.model.nets[net_name] = paddle.DataParallel(net, strategy) self.model.nets[net_name] = paddle.DataParallel(net, strategy)
def learning_rate_scheduler_step(self):
    """Advance the model's learning-rate scheduler(s) by one step.

    ``self.model.lr_scheduler`` may be either a dict of schedulers (each
    value is stepped once) or a single ``paddle.optimizer.lr.LRScheduler``.

    Raises:
        ValueError: if the scheduler is neither a dict nor an LRScheduler.
    """
    schedulers = self.model.lr_scheduler

    # Several schedulers, e.g. one per network: step each of them.
    if isinstance(schedulers, dict):
        for scheduler in schedulers.values():
            scheduler.step()
        return

    if not isinstance(schedulers, paddle.optimizer.lr.LRScheduler):
        raise ValueError(
            'lr schedulter must be a dict or an instance of LRScheduler')

    # A single scheduler shared by the whole model.
    schedulers.step()
def train(self): def train(self):
reader_cost_averager = TimeAverager() reader_cost_averager = TimeAverager()
batch_cost_averager = TimeAverager() batch_cost_averager = TimeAverager()
...@@ -179,7 +192,7 @@ class Trainer: ...@@ -179,7 +192,7 @@ class Trainer:
if self.current_iter % self.visual_interval == 0: if self.current_iter % self.visual_interval == 0:
self.visual('visual_train') self.visual('visual_train')
self.model.lr_scheduler.step() self.learning_rate_scheduler_step()
if self.validate_interval > -1 and self.current_iter % self.validate_interval == 0: if self.validate_interval > -1 and self.current_iter % self.validate_interval == 0:
self.test() self.test()
......
...@@ -22,5 +22,6 @@ from .esrgan_model import ESRGAN ...@@ -22,5 +22,6 @@ from .esrgan_model import ESRGAN
from .ugatit_model import UGATITModel from .ugatit_model import UGATITModel
from .dc_gan_model import DCGANModel from .dc_gan_model import DCGANModel
from .animeganv2_model import AnimeGANV2Model, AnimeGANV2PreTrainModel from .animeganv2_model import AnimeGANV2Model, AnimeGANV2PreTrainModel
from .styleganv2_model import StyleGAN2Model
from .wav2lip_model import Wav2LipModel from .wav2lip_model import Wav2LipModel
from .wav2lip_hq_model import Wav2LipModelHq from .wav2lip_hq_model import Wav2LipModelHq
...@@ -59,8 +59,7 @@ class ConvLayer(nn.Sequential): ...@@ -59,8 +59,7 @@ class ConvLayer(nn.Sequential):
padding=self.padding, padding=self.padding,
stride=stride, stride=stride,
bias=bias and not activate, bias=bias and not activate,
) ))
)
if activate: if activate:
layers.append(FusedLeakyReLU(out_channel, bias=bias)) layers.append(FusedLeakyReLU(out_channel, bias=bias))
...@@ -75,9 +74,12 @@ class ResBlock(nn.Layer): ...@@ -75,9 +74,12 @@ class ResBlock(nn.Layer):
self.conv1 = ConvLayer(in_channel, in_channel, 3) self.conv1 = ConvLayer(in_channel, in_channel, 3)
self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True) self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
self.skip = ConvLayer( self.skip = ConvLayer(in_channel,
in_channel, out_channel, 1, downsample=True, activate=False, bias=False out_channel,
) 1,
downsample=True,
activate=False,
bias=False)
def forward(self, input): def forward(self, input):
out = self.conv1(input) out = self.conv1(input)
...@@ -89,6 +91,21 @@ class ResBlock(nn.Layer): ...@@ -89,6 +91,21 @@ class ResBlock(nn.Layer):
return out return out
# temporary workaround for the pow double-grad problem
def var(x, axis=None, unbiased=True, keepdim=False, name=None):
    """Compute the variance of ``x`` along ``axis`` without using ``pow``.

    The squared deviation is formed by an explicit multiplication instead of
    a power op.

    Args:
        x: input tensor.
        axis: axis (or axes) to reduce over; forwarded to ``paddle.mean`` and
            ``paddle.sum``.
        unbiased: if True, divide by ``n - 1`` (or by 1 when ``n <= 1``)
            instead of ``n``.
        keepdim: whether the reduced axes are kept in the result.
        name: optional op name forwarded to the paddle calls.

    Returns:
        Tensor holding the variance.
    """
    # The mean always keeps its reduced dims so it broadcasts against x.
    mean = paddle.mean(x, axis, True, name)
    centered = x - mean
    out = paddle.sum(centered * centered, axis, keepdim=keepdim, name=name)

    # Number of elements reduced per output entry.
    dtype = x.dtype
    total = paddle.cast(paddle.numel(x), dtype)
    kept = paddle.cast(paddle.numel(out), dtype)
    n = total / kept

    if unbiased:
        # n - 1 divisor, falling back to 1 when n <= 1 to avoid dividing by 0.
        one_const = paddle.ones([1], dtype)
        n = paddle.where(n > one_const, n - 1., one_const)

    out /= n
    return out
@DISCRIMINATORS.register() @DISCRIMINATORS.register()
class StyleGANv2Discriminator(nn.Layer): class StyleGANv2Discriminator(nn.Layer):
def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]): def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
...@@ -113,7 +130,7 @@ class StyleGANv2Discriminator(nn.Layer): ...@@ -113,7 +130,7 @@ class StyleGANv2Discriminator(nn.Layer):
in_channel = channels[size] in_channel = channels[size]
for i in range(log_size, 2, -1): for i in range(log_size, 2, -1):
out_channel = channels[2 ** (i - 1)] out_channel = channels[2**(i - 1)]
convs.append(ResBlock(in_channel, out_channel, blur_kernel)) convs.append(ResBlock(in_channel, out_channel, blur_kernel))
...@@ -126,7 +143,9 @@ class StyleGANv2Discriminator(nn.Layer): ...@@ -126,7 +143,9 @@ class StyleGANv2Discriminator(nn.Layer):
self.final_conv = ConvLayer(in_channel + 1, channels[4], 3) self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
self.final_linear = nn.Sequential( self.final_linear = nn.Sequential(
EqualLinear(channels[4] * 4 * 4, channels[4], activation="fused_lrelu"), EqualLinear(channels[4] * 4 * 4,
channels[4],
activation="fused_lrelu"),
EqualLinear(channels[4], 1), EqualLinear(channels[4], 1),
) )
...@@ -135,10 +154,9 @@ class StyleGANv2Discriminator(nn.Layer): ...@@ -135,10 +154,9 @@ class StyleGANv2Discriminator(nn.Layer):
batch, channel, height, width = out.shape batch, channel, height, width = out.shape
group = min(batch, self.stddev_group) group = min(batch, self.stddev_group)
stddev = out.reshape(( stddev = out.reshape((group, -1, self.stddev_feat,
group, -1, self.stddev_feat, channel // self.stddev_feat, height, width channel // self.stddev_feat, height, width))
)) stddev = paddle.sqrt(var(stddev, 0, unbiased=False) + 1e-8)
stddev = paddle.sqrt(stddev.var(0, unbiased=False) + 1e-8)
stddev = stddev.mean([2, 3, 4], keepdim=True).squeeze(2) stddev = stddev.mean([2, 3, 4], keepdim=True).squeeze(2)
stddev = stddev.tile((group, 1, height, width)) stddev = stddev.tile((group, 1, height, width))
out = paddle.concat([out, stddev], 1) out = paddle.concat([out, stddev], 1)
......
...@@ -29,7 +29,8 @@ class PixelNorm(nn.Layer): ...@@ -29,7 +29,8 @@ class PixelNorm(nn.Layer):
super().__init__() super().__init__()
def forward(self, input): def forward(self, input):
return input * paddle.rsqrt(paddle.mean(input ** 2, 1, keepdim=True) + 1e-8) return input * paddle.rsqrt(
paddle.mean(input * input, 1, keepdim=True) + 1e-8)
class ModulatedConv2D(nn.Layer): class ModulatedConv2D(nn.Layer):
...@@ -59,7 +60,9 @@ class ModulatedConv2D(nn.Layer): ...@@ -59,7 +60,9 @@ class ModulatedConv2D(nn.Layer):
pad0 = (p + 1) // 2 + factor - 1 pad0 = (p + 1) // 2 + factor - 1
pad1 = p // 2 + 1 pad1 = p // 2 + 1
self.blur = Upfirdn2dBlur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor) self.blur = Upfirdn2dBlur(blur_kernel,
pad=(pad0, pad1),
upsample_factor=factor)
if downsample: if downsample:
factor = 2 factor = 2
...@@ -69,13 +72,13 @@ class ModulatedConv2D(nn.Layer): ...@@ -69,13 +72,13 @@ class ModulatedConv2D(nn.Layer):
self.blur = Upfirdn2dBlur(blur_kernel, pad=(pad0, pad1)) self.blur = Upfirdn2dBlur(blur_kernel, pad=(pad0, pad1))
fan_in = in_channel * kernel_size ** 2 fan_in = in_channel * (kernel_size * kernel_size)
self.scale = 1 / math.sqrt(fan_in) self.scale = 1 / math.sqrt(fan_in)
self.padding = kernel_size // 2 self.padding = kernel_size // 2
self.weight = self.create_parameter( self.weight = self.create_parameter(
(1, out_channel, in_channel, kernel_size, kernel_size), default_initializer=nn.initializer.Normal() (1, out_channel, in_channel, kernel_size, kernel_size),
) default_initializer=nn.initializer.Normal())
self.modulation = EqualLinear(style_dim, in_channel, bias_init=1) self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
...@@ -84,8 +87,7 @@ class ModulatedConv2D(nn.Layer): ...@@ -84,8 +87,7 @@ class ModulatedConv2D(nn.Layer):
def __repr__(self): def __repr__(self):
return ( return (
f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, " f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, "
f"upsample={self.upsample}, downsample={self.downsample})" f"upsample={self.upsample}, downsample={self.downsample})")
)
def forward(self, input, style): def forward(self, input, style):
batch, in_channel, height, width = input.shape batch, in_channel, height, width = input.shape
...@@ -94,22 +96,24 @@ class ModulatedConv2D(nn.Layer): ...@@ -94,22 +96,24 @@ class ModulatedConv2D(nn.Layer):
weight = self.scale * self.weight * style weight = self.scale * self.weight * style
if self.demodulate: if self.demodulate:
demod = paddle.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8) demod = paddle.rsqrt((weight * weight).sum([2, 3, 4]) + 1e-8)
weight = weight * demod.reshape((batch, self.out_channel, 1, 1, 1)) weight = weight * demod.reshape((batch, self.out_channel, 1, 1, 1))
weight = weight.reshape(( weight = weight.reshape((batch * self.out_channel, in_channel,
batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size self.kernel_size, self.kernel_size))
))
if self.upsample: if self.upsample:
input = input.reshape((1, batch * in_channel, height, width)) input = input.reshape((1, batch * in_channel, height, width))
weight = weight.reshape(( weight = weight.reshape((batch, self.out_channel, in_channel,
batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size self.kernel_size, self.kernel_size))
)) weight = weight.transpose((0, 2, 1, 3, 4)).reshape(
weight = weight.transpose((0, 2, 1, 3, 4)).reshape(( (batch * in_channel, self.out_channel, self.kernel_size,
batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size self.kernel_size))
)) out = F.conv2d_transpose(input,
out = F.conv2d_transpose(input, weight, padding=0, stride=2, groups=batch) weight,
padding=0,
stride=2,
groups=batch)
_, _, height, width = out.shape _, _, height, width = out.shape
out = out.reshape((batch, self.out_channel, height, width)) out = out.reshape((batch, self.out_channel, height, width))
out = self.blur(out) out = self.blur(out)
...@@ -135,7 +139,8 @@ class NoiseInjection(nn.Layer): ...@@ -135,7 +139,8 @@ class NoiseInjection(nn.Layer):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.weight = self.create_parameter((1,), default_initializer=nn.initializer.Constant(0.0)) self.weight = self.create_parameter(
(1, ), default_initializer=nn.initializer.Constant(0.0))
def forward(self, image, noise=None): def forward(self, image, noise=None):
if noise is None: if noise is None:
...@@ -149,7 +154,9 @@ class ConstantInput(nn.Layer): ...@@ -149,7 +154,9 @@ class ConstantInput(nn.Layer):
def __init__(self, channel, size=4): def __init__(self, channel, size=4):
super().__init__() super().__init__()
self.input = self.create_parameter((1, channel, size, size), default_initializer=nn.initializer.Normal()) self.input = self.create_parameter(
(1, channel, size, size),
default_initializer=nn.initializer.Normal())
def forward(self, input): def forward(self, input):
batch = input.shape[0] batch = input.shape[0]
...@@ -193,14 +200,23 @@ class StyledConv(nn.Layer): ...@@ -193,14 +200,23 @@ class StyledConv(nn.Layer):
class ToRGB(nn.Layer): class ToRGB(nn.Layer):
def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]): def __init__(self,
in_channel,
style_dim,
upsample=True,
blur_kernel=[1, 3, 3, 1]):
super().__init__() super().__init__()
if upsample: if upsample:
self.upsample = Upfirdn2dUpsample(blur_kernel) self.upsample = Upfirdn2dUpsample(blur_kernel)
self.conv = ModulatedConv2D(in_channel, 3, 1, style_dim, demodulate=False) self.conv = ModulatedConv2D(in_channel,
self.bias = self.create_parameter((1, 3, 1, 1), nn.initializer.Constant(0.0)) 3,
1,
style_dim,
demodulate=False)
self.bias = self.create_parameter((1, 3, 1, 1),
nn.initializer.Constant(0.0))
def forward(self, input, style, skip=None): def forward(self, input, style, skip=None):
out = self.conv(input, style) out = self.conv(input, style)
...@@ -235,10 +251,10 @@ class StyleGANv2Generator(nn.Layer): ...@@ -235,10 +251,10 @@ class StyleGANv2Generator(nn.Layer):
for i in range(n_mlp): for i in range(n_mlp):
layers.append( layers.append(
EqualLinear( EqualLinear(style_dim,
style_dim, style_dim, lr_mul=lr_mlp, activation="fused_lrelu" style_dim,
) lr_mul=lr_mlp,
) activation="fused_lrelu"))
self.style = nn.Sequential(*layers) self.style = nn.Sequential(*layers)
...@@ -255,9 +271,11 @@ class StyleGANv2Generator(nn.Layer): ...@@ -255,9 +271,11 @@ class StyleGANv2Generator(nn.Layer):
} }
self.input = ConstantInput(self.channels[4]) self.input = ConstantInput(self.channels[4])
self.conv1 = StyledConv( self.conv1 = StyledConv(self.channels[4],
self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel self.channels[4],
) 3,
style_dim,
blur_kernel=blur_kernel)
self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False) self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)
self.log_size = int(math.log(size, 2)) self.log_size = int(math.log(size, 2))
...@@ -272,11 +290,12 @@ class StyleGANv2Generator(nn.Layer): ...@@ -272,11 +290,12 @@ class StyleGANv2Generator(nn.Layer):
for layer_idx in range(self.num_layers): for layer_idx in range(self.num_layers):
res = (layer_idx + 5) // 2 res = (layer_idx + 5) // 2
shape = [1, 1, 2 ** res, 2 ** res] shape = [1, 1, 2**res, 2**res]
self.noises.register_buffer(f"noise_{layer_idx}", paddle.randn(shape)) self.noises.register_buffer(f"noise_{layer_idx}",
paddle.randn(shape))
for i in range(3, self.log_size + 1): for i in range(3, self.log_size + 1):
out_channel = self.channels[2 ** i] out_channel = self.channels[2**i]
self.convs.append( self.convs.append(
StyledConv( StyledConv(
...@@ -286,14 +305,14 @@ class StyleGANv2Generator(nn.Layer): ...@@ -286,14 +305,14 @@ class StyleGANv2Generator(nn.Layer):
style_dim, style_dim,
upsample=True, upsample=True,
blur_kernel=blur_kernel, blur_kernel=blur_kernel,
) ))
)
self.convs.append( self.convs.append(
StyledConv( StyledConv(out_channel,
out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel out_channel,
) 3,
) style_dim,
blur_kernel=blur_kernel))
self.to_rgbs.append(ToRGB(out_channel, style_dim)) self.to_rgbs.append(ToRGB(out_channel, style_dim))
...@@ -302,18 +321,16 @@ class StyleGANv2Generator(nn.Layer): ...@@ -302,18 +321,16 @@ class StyleGANv2Generator(nn.Layer):
self.n_latent = self.log_size * 2 - 2 self.n_latent = self.log_size * 2 - 2
def make_noise(self): def make_noise(self):
noises = [paddle.randn((1, 1, 2 ** 2, 2 ** 2))] noises = [paddle.randn((1, 1, 2**2, 2**2))]
for i in range(3, self.log_size + 1): for i in range(3, self.log_size + 1):
for _ in range(2): for _ in range(2):
noises.append(paddle.randn((1, 1, 2 ** i, 2 ** i))) noises.append(paddle.randn((1, 1, 2**i, 2**i)))
return noises return noises
def mean_latent(self, n_latent): def mean_latent(self, n_latent):
latent_in = paddle.randn(( latent_in = paddle.randn((n_latent, self.style_dim))
n_latent, self.style_dim
))
latent = self.style(latent_in).mean(0, keepdim=True) latent = self.style(latent_in).mean(0, keepdim=True)
return latent return latent
...@@ -340,16 +357,16 @@ class StyleGANv2Generator(nn.Layer): ...@@ -340,16 +357,16 @@ class StyleGANv2Generator(nn.Layer):
noise = [None] * self.num_layers noise = [None] * self.num_layers
else: else:
noise = [ noise = [
getattr(self.noises, f"noise_{i}") for i in range(self.num_layers) getattr(self.noises, f"noise_{i}")
for i in range(self.num_layers)
] ]
if truncation < 1: if truncation < 1:
style_t = [] style_t = []
for style in styles: for style in styles:
style_t.append( style_t.append(truncation_latent + truncation *
truncation_latent + truncation * (style - truncation_latent) (style - truncation_latent))
)
styles = style_t styles = style_t
...@@ -367,7 +384,8 @@ class StyleGANv2Generator(nn.Layer): ...@@ -367,7 +384,8 @@ class StyleGANv2Generator(nn.Layer):
inject_index = random.randint(1, self.n_latent - 1) inject_index = random.randint(1, self.n_latent - 1)
latent = styles[0].unsqueeze(1).tile((1, inject_index, 1)) latent = styles[0].unsqueeze(1).tile((1, inject_index, 1))
latent2 = styles[1].unsqueeze(1).tile((1, self.n_latent - inject_index, 1)) latent2 = styles[1].unsqueeze(1).tile(
(1, self.n_latent - inject_index, 1))
latent = paddle.concat([latent, latent2], 1) latent = paddle.concat([latent, latent2], 1)
...@@ -377,9 +395,11 @@ class StyleGANv2Generator(nn.Layer): ...@@ -377,9 +395,11 @@ class StyleGANv2Generator(nn.Layer):
skip = self.to_rgb1(out, latent[:, 1]) skip = self.to_rgb1(out, latent[:, 1])
i = 1 i = 1
for conv1, conv2, noise1, noise2, to_rgb in zip( for conv1, conv2, noise1, noise2, to_rgb in zip(self.convs[::2],
self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs self.convs[1::2],
): noise[1::2],
noise[2::2],
self.to_rgbs):
out = conv1(out, latent[:, i], noise=noise1) out = conv1(out, latent[:, i], noise=noise1)
out = conv2(out, latent[:, i + 1], noise=noise2) out = conv2(out, latent[:, i + 1], noise=noise2)
skip = to_rgb(out, latent[:, i + 2], skip) skip = to_rgb(out, latent[:, i + 2], skip)
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import random
import paddle
import paddle.nn as nn
from .base_model import BaseModel
from .builder import MODELS
from .criterions import build_criterion
from .generators.builder import build_generator
from .discriminators.builder import build_discriminator
from ..solver import build_lr_scheduler, build_optimizer
def r1_penalty(real_pred, real_img):
    """R1 gradient penalty for the discriminator, evaluated on real images.

    Penalizes the gradient of the discriminator output with respect to real
    data only: once the generator matches the true data distribution and the
    discriminator is 0 on the data manifold, the penalty prevents the
    discriminator from building a non-zero gradient orthogonal to that
    manifold without paying a loss in the GAN game.

    Ref:
        Eq. 9 in "Which training methods for GANs do actually converge".

    Args:
        real_pred: discriminator output for ``real_img``.
        real_img: real images the prediction was computed from; gradients
            are taken with respect to this tensor.

    Returns:
        Scalar tensor: squared gradient norm summed per sample, averaged
        over the first (batch) axis.
    """
    # create_graph=True keeps the gradient differentiable so the penalty can
    # itself be back-propagated.
    grads = paddle.grad(outputs=real_pred.sum(),
                        inputs=real_img,
                        create_graph=True)
    grad_real = grads[0]

    # Square via multiplication, flatten everything but the batch axis.
    batch = grad_real.shape[0]
    squared = grad_real * grad_real
    per_sample = squared.reshape([batch, -1]).sum(1)
    return per_sample.mean()
def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
    """Path-length regularization term for the generator.

    Args:
        fake_img: generated images; axes 2 and 3 are used as the spatial
            size for scaling the random projection noise.
        latents: latent tensor the images were generated from; the gradient
            is reduced with ``sum(2)`` then ``mean(1)``, so it is expected to
            have at least three axes.
        mean_path_length: running mean of the path length from earlier calls.
        decay: update rate of the running mean.

    Returns:
        Tuple ``(path_penalty, mean_of_path_lengths, new_running_mean)``;
        the last two are detached from the graph.
    """
    height, width = fake_img.shape[2], fake_img.shape[3]
    noise = paddle.randn(fake_img.shape) / math.sqrt(height * width)

    # Gradient of a random projection of the image w.r.t. the latents;
    # create_graph=True so the penalty remains differentiable.
    projection = (fake_img * noise).sum()
    grad = paddle.grad(outputs=projection,
                       inputs=latents,
                       create_graph=True)[0]
    path_lengths = paddle.sqrt((grad * grad).sum(2).mean(1))

    # Moving average of the path length.
    batch_mean = path_lengths.mean()
    path_mean = mean_path_length + decay * (batch_mean - mean_path_length)

    # Squared deviation from the running mean (multiplication, not pow).
    deviation = path_lengths - path_mean
    path_penalty = (deviation * deviation).mean()

    return path_penalty, path_lengths.detach().mean(), path_mean.detach()
@MODELS.register()
class StyleGAN2Model(BaseModel):
"""
This class implements the StyleGANV2 model, for learning image-to-image translation without paired data.
StyleGAN2 paper: https://arxiv.org/pdf/1912.04958.pdf
"""
def __init__(self,
generator,
discriminator=None,
gan_criterion=None,
num_style_feat=512,
mixing_prob=0.9,
r1_reg_weight=10.,
path_reg_weight=2.,
path_batch_shrink=2.,
params=None):
"""Initialize the CycleGAN class.
Args:
generator (dict): config of generator.
discriminator (dict): config of discriminator.
gan_criterion (dict): config of gan criterion.
"""
super(StyleGAN2Model, self).__init__(params)
self.gen_iters = 4 if self.params is None else self.params.get(
'gen_iters', 4)
self.disc_iters = 16 if self.params is None else self.params.get(
'disc_iters', 16)
self.disc_start_iters = (0 if self.params is None else self.params.get(
'disc_start_iters', 0))
self.visual_iters = (500 if self.params is None else self.params.get(
'visual_iters', 500))
self.mixing_prob = mixing_prob
self.num_style_feat = num_style_feat
self.r1_reg_weight = r1_reg_weight
self.path_reg_weight = path_reg_weight
self.path_batch_shrink = path_batch_shrink
self.mean_path_length = 0
self.nets['gen'] = build_generator(generator)
# define discriminators
if discriminator:
self.nets['disc'] = build_discriminator(discriminator)
self.nets['gen_ema'] = build_generator(generator)
self.model_ema(0)
self.nets['gen'].train()
self.nets['gen_ema'].eval()
self.nets['disc'].train()
self.current_iter = 1
# define loss functions
if gan_criterion:
self.gan_criterion = build_criterion(gan_criterion)
def setup_lr_schedulers(self, cfg):
self.lr_scheduler = dict()
gen_cfg = cfg.copy()
net_g_reg_ratio = self.gen_iters / (self.gen_iters + 1)
gen_cfg['learning_rate'] = cfg['learning_rate'] * net_g_reg_ratio
self.lr_scheduler['gen'] = build_lr_scheduler(gen_cfg)
disc_cfg = cfg.copy()
net_d_reg_ratio = self.disc_iters / (self.disc_iters + 1)
disc_cfg['learning_rate'] = cfg['learning_rate'] * net_d_reg_ratio
self.lr_scheduler['disc'] = build_lr_scheduler(disc_cfg)
return self.lr_scheduler
def setup_optimizers(self, lr, cfg):
for opt_name, opt_cfg in cfg.items():
if opt_name == 'optimG':
_lr = lr['gen']
elif opt_name == 'optimD':
_lr = lr['disc']
else:
raise ValueError("opt name must be in ['optimG', optimD]")
cfg_ = opt_cfg.copy()
net_names = cfg_.pop('net_names')
parameters = []
for net_name in net_names:
parameters += self.nets[net_name].parameters()
self.optimizers[opt_name] = build_optimizer(cfg_, _lr, parameters)
return self.optimizers
def get_bare_model(self, net):
"""Get bare model, especially under wrapping with DataParallel.
"""
if isinstance(net, (paddle.DataParallel)):
net = net._layers
return net
def model_ema(self, decay=0.999):
net_g = self.get_bare_model(self.nets['gen'])
net_g_params = dict(net_g.named_parameters())
neg_g_ema = self.get_bare_model(self.nets['gen_ema'])
net_g_ema_params = dict(neg_g_ema.named_parameters())
for k in net_g_ema_params.keys():
net_g_ema_params[k].set_value(net_g_ema_params[k] * (decay) +
(net_g_params[k] * (1 - decay)))
def setup_input(self, input):
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
Args:
input (dict): include the data itself and its metadata information.
"""
self.real_img = paddle.fluid.dygraph.to_variable(input['A'])
def forward(self):
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
pass
def make_noise(self, batch, num_noise):
if num_noise == 1:
noises = paddle.randn([batch, self.num_style_feat])
else:
noises = []
for _ in range(num_noise):
noises.append(paddle.randn([batch, self.num_style_feat]))
return noises
def mixing_noise(self, batch, prob):
if random.random() < prob:
return self.make_noise(batch, 2)
else:
return [self.make_noise(batch, 1)]
def train_iter(self, optimizers=None):
current_iter = self.current_iter
self.set_requires_grad(self.nets['disc'], True)
optimizers['optimD'].clear_grad()
batch = self.real_img.shape[0]
noise = self.mixing_noise(batch, self.mixing_prob)
fake_img, _ = self.nets['gen'](noise)
self.visual_items['real_img'] = self.real_img
self.visual_items['fake_img'] = fake_img
fake_pred = self.nets['disc'](fake_img.detach())
real_pred = self.nets['disc'](self.real_img)
# wgan loss with softplus (logistic loss) for discriminator
l_d_total = 0.
l_d = self.gan_criterion(real_pred, True,
is_disc=True) + self.gan_criterion(
fake_pred, False, is_disc=True)
self.losses['l_d'] = l_d
# In wgan, real_score should be positive and fake_score should be
# negative
self.losses['real_score'] = real_pred.detach().mean()
self.losses['fake_score'] = fake_pred.detach().mean()
l_d_total += l_d
if current_iter % self.disc_iters == 0:
self.real_img.stop_gradient = False
real_pred = self.nets['disc'](self.real_img)
l_d_r1 = r1_penalty(real_pred, self.real_img)
l_d_r1 = (self.r1_reg_weight / 2 * l_d_r1 * self.disc_iters +
0 * real_pred[0])
self.losses['l_d_r1'] = l_d_r1.detach().mean()
l_d_total += l_d_r1
l_d_total.backward()
optimizers['optimD'].step()
self.set_requires_grad(self.nets['disc'], False)
optimizers['optimG'].clear_grad()
noise = self.mixing_noise(batch, self.mixing_prob)
fake_img, _ = self.nets['gen'](noise)
fake_pred = self.nets['disc'](fake_img)
# wgan loss with softplus (non-saturating loss) for generator
l_g_total = 0.
l_g = self.gan_criterion(fake_pred, True, is_disc=False)
self.losses['l_g'] = l_g
l_g_total += l_g
if current_iter % self.gen_iters == 0:
path_batch_size = max(1, batch // self.path_batch_shrink)
noise = self.mixing_noise(path_batch_size, self.mixing_prob)
fake_img, latents = self.nets['gen'](noise, return_latents=True)
l_g_path, path_lengths, self.mean_path_length = g_path_regularize(
fake_img, latents, self.mean_path_length)
l_g_path = (self.path_reg_weight * self.gen_iters * l_g_path +
0 * fake_img[0, 0, 0, 0])
l_g_total += l_g_path
self.losses['l_g_path'] = l_g_path.detach().mean()
self.losses['path_length'] = path_lengths
l_g_total.backward()
optimizers['optimG'].step()
# EMA
self.model_ema(decay=0.5**(32 / (10 * 1000)))
if self.current_iter % self.visual_iters:
sample_z = [self.make_noise(1, 1)]
sample, _ = self.nets['gen_ema'](sample_z)
self.visual_items['fake_img_ema'] = sample
self.current_iter += 1
...@@ -24,21 +24,26 @@ class EqualConv2D(nn.Layer): ...@@ -24,21 +24,26 @@ class EqualConv2D(nn.Layer):
"""This convolutional layer class stabilizes the learning rate changes of its parameters. """This convolutional layer class stabilizes the learning rate changes of its parameters.
Equalizing learning rate keeps the weights in the network at a similar scale during training. Equalizing learning rate keeps the weights in the network at a similar scale during training.
""" """
def __init__( def __init__(self,
self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True in_channel,
): out_channel,
kernel_size,
stride=1,
padding=0,
bias=True):
super().__init__() super().__init__()
self.weight = self.create_parameter( self.weight = self.create_parameter(
(out_channel, in_channel, kernel_size, kernel_size), default_initializer=nn.initializer.Normal() (out_channel, in_channel, kernel_size, kernel_size),
) default_initializer=nn.initializer.Normal())
self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2) self.scale = 1 / math.sqrt(in_channel * (kernel_size * kernel_size))
self.stride = stride self.stride = stride
self.padding = padding self.padding = padding
if bias: if bias:
self.bias = self.create_parameter((out_channel,), nn.initializer.Constant(0.0)) self.bias = self.create_parameter((out_channel, ),
nn.initializer.Constant(0.0))
else: else:
self.bias = None self.bias = None
...@@ -65,16 +70,22 @@ class EqualLinear(nn.Layer): ...@@ -65,16 +70,22 @@ class EqualLinear(nn.Layer):
"""This linear layer class stabilizes the learning rate changes of its parameters. """This linear layer class stabilizes the learning rate changes of its parameters.
Equalizing learning rate keeps the weights in the network at a similar scale during training. Equalizing learning rate keeps the weights in the network at a similar scale during training.
""" """
def __init__( def __init__(self,
self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None in_dim,
): out_dim,
bias=True,
bias_init=0,
lr_mul=1,
activation=None):
super().__init__() super().__init__()
self.weight = self.create_parameter((in_dim, out_dim), default_initializer=nn.initializer.Normal()) self.weight = self.create_parameter(
self.weight[:] = (self.weight / lr_mul).detach() (in_dim, out_dim), default_initializer=nn.initializer.Normal())
self.weight.set_value((self.weight / lr_mul))
if bias: if bias:
self.bias = self.create_parameter((out_dim,), nn.initializer.Constant(bias_init)) self.bias = self.create_parameter(
(out_dim, ), nn.initializer.Constant(bias_init))
else: else:
self.bias = None self.bias = None
...@@ -90,9 +101,9 @@ class EqualLinear(nn.Layer): ...@@ -90,9 +101,9 @@ class EqualLinear(nn.Layer):
out = fused_leaky_relu(out, self.bias * self.lr_mul) out = fused_leaky_relu(out, self.bias * self.lr_mul)
else: else:
out = F.linear( out = F.linear(input,
input, self.weight * self.scale, bias=self.bias * self.lr_mul self.weight * self.scale,
) bias=self.bias * self.lr_mul)
return out return out
......
...@@ -17,9 +17,8 @@ import paddle.nn as nn ...@@ -17,9 +17,8 @@ import paddle.nn as nn
import paddle.nn.functional as F import paddle.nn.functional as F
def upfirdn2d_native( def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1,
input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 pad_y0, pad_y1):
):
_, channel, in_h, in_w = input.shape _, channel, in_h, in_w = input.shape
input = input.reshape((-1, in_h, in_w, 1)) input = input.reshape((-1, in_h, in_w, 1))
...@@ -27,25 +26,24 @@ def upfirdn2d_native( ...@@ -27,25 +26,24 @@ def upfirdn2d_native(
kernel_h, kernel_w = kernel.shape kernel_h, kernel_w = kernel.shape
out = input.reshape((-1, in_h, 1, in_w, 1, minor)) out = input.reshape((-1, in_h, 1, in_w, 1, minor))
out = out.transpose((0,1,3,5,2,4)) out = out.transpose((0, 1, 3, 5, 2, 4))
out = out.reshape((-1,1,1,1)) out = out.reshape((-1, 1, 1, 1))
out = F.pad(out, [0, up_x - 1, 0, up_y - 1]) out = F.pad(out, [0, up_x - 1, 0, up_y - 1])
out = out.reshape((-1, in_h, in_w, minor, up_y, up_x)) out = out.reshape((-1, in_h, in_w, minor, up_y, up_x))
out = out.transpose((0,3,1,4,2,5)) out = out.transpose((0, 3, 1, 4, 2, 5))
out = out.reshape((-1, minor, in_h * up_y, in_w * up_x)) out = out.reshape((-1, minor, in_h * up_y, in_w * up_x))
out = F.pad( out = F.pad(
out, [max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] out, [max(pad_x0, 0),
) max(pad_x1, 0),
out = out[ max(pad_y0, 0),
:,:, max(pad_y1, 0)])
max(-pad_y0, 0) : out.shape[2] - max(-pad_y1, 0), out = out[:, :,
max(-pad_x0, 0) : out.shape[3] - max(-pad_x1, 0), max(-pad_y0, 0):out.shape[2] - max(-pad_y1, 0),
] max(-pad_x0, 0):out.shape[3] - max(-pad_x1, 0), ]
out = out.reshape(( out = out.reshape(
[-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] ([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]))
))
w = paddle.flip(kernel, [0, 1]).reshape((1, 1, kernel_h, kernel_w)) w = paddle.flip(kernel, [0, 1]).reshape((1, 1, kernel_h, kernel_w))
out = F.conv2d(out, w) out = F.conv2d(out, w)
out = out.reshape(( out = out.reshape((
...@@ -64,9 +62,8 @@ def upfirdn2d_native( ...@@ -64,9 +62,8 @@ def upfirdn2d_native(
def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
    """Call ``upfirdn2d_native`` with identical settings on both axes.

    Args:
        input: Tensor forwarded unchanged to ``upfirdn2d_native``.
        kernel: FIR kernel forwarded unchanged to ``upfirdn2d_native``.
        up: Upsampling factor, used for both the x and y direction.
        down: Downsampling factor, used for both the x and y direction.
        pad: Two-element sequence; ``pad[0]`` is passed as the leading pad
            and ``pad[1]`` as the trailing pad for both x and y.

    Returns:
        Whatever ``upfirdn2d_native`` returns for these arguments.
    """
    lead_pad, trail_pad = pad[0], pad[1]
    return upfirdn2d_native(input, kernel,
                            up, up,
                            down, down,
                            lead_pad, trail_pad,
                            lead_pad, trail_pad)
...@@ -87,7 +84,7 @@ class Upfirdn2dUpsample(nn.Layer): ...@@ -87,7 +84,7 @@ class Upfirdn2dUpsample(nn.Layer):
super().__init__() super().__init__()
self.factor = factor self.factor = factor
kernel = make_kernel(kernel) * (factor ** 2) kernel = make_kernel(kernel) * (factor * factor)
self.register_buffer("kernel", kernel) self.register_buffer("kernel", kernel)
p = kernel.shape[0] - factor p = kernel.shape[0] - factor
...@@ -98,7 +95,11 @@ class Upfirdn2dUpsample(nn.Layer): ...@@ -98,7 +95,11 @@ class Upfirdn2dUpsample(nn.Layer):
self.pad = (pad0, pad1) self.pad = (pad0, pad1)
def forward(self, input):
    """Run ``upfirdn2d`` on ``input`` with ``up=self.factor`` and ``down=1``.

    Uses the kernel and padding stored on the layer (``self.kernel``,
    ``self.pad``) and returns the result of ``upfirdn2d`` directly.
    """
    return upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
...@@ -119,7 +120,11 @@ class Upfirdn2dDownsample(nn.Layer): ...@@ -119,7 +120,11 @@ class Upfirdn2dDownsample(nn.Layer):
self.pad = (pad0, pad1) self.pad = (pad0, pad1)
def forward(self, input):
    """Run ``upfirdn2d`` on ``input`` with ``up=1`` and ``down=self.factor``.

    Uses the kernel and padding stored on the layer (``self.kernel``,
    ``self.pad``) and returns the result of ``upfirdn2d`` directly.
    """
    return upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
...@@ -131,9 +136,9 @@ class Upfirdn2dBlur(nn.Layer): ...@@ -131,9 +136,9 @@ class Upfirdn2dBlur(nn.Layer):
kernel = make_kernel(kernel) kernel = make_kernel(kernel)
if upsample_factor > 1: if upsample_factor > 1:
kernel = kernel * (upsample_factor ** 2) kernel = kernel * (upsample_factor * upsample_factor)
self.register_buffer("kernel", kernel) self.register_buffer("kernel", kernel, persistable=False)
self.pad = pad self.pad = pad
......
import librosa
import librosa.filters
import numpy as np import numpy as np
from scipy import signal from scipy import signal
from scipy.io import wavfile from scipy.io import wavfile
from paddle.utils import try_import
from .audio_config import get_audio_config from .audio_config import get_audio_config
audio_config = get_audio_config() audio_config = get_audio_config()
def load_wav(path, sr):
    """Load the audio file at ``path`` through librosa and return its samples.

    librosa is imported lazily via ``try_import`` so that it is only
    required when audio is actually loaded.

    Args:
        path: Location of the audio file, passed to ``librosa.core.load``.
        sr: Sampling rate forwarded to ``librosa.core.load`` as ``sr``.

    Returns:
        The first element of the value returned by ``librosa.core.load``
        (the sample data; the accompanying sampling rate is discarded).
    """
    librosa = try_import('librosa')
    loaded = librosa.core.load(path, sr=sr)
    samples = loaded[0]
    return samples
...@@ -19,6 +19,7 @@ def save_wav(wav, path, sr): ...@@ -19,6 +19,7 @@ def save_wav(wav, path, sr):
def save_wavenet_wav(wav, path, sr):
    """Write ``wav`` to ``path`` as a WAV file with sampling rate ``sr``.

    librosa is imported lazily via ``try_import``. ``librosa.output`` was
    removed in librosa 0.8.0, so the original unconditional call to
    ``librosa.output.write_wav`` raises ``AttributeError`` on current
    releases. The legacy writer is still used when it exists; otherwise the
    samples are written with ``scipy.io.wavfile.write``.

    Args:
        wav: Audio samples to store.
        path: Destination file path.
        sr: Sampling rate written into the file.
    """
    librosa = try_import('librosa')
    legacy_output = getattr(librosa, 'output', None)
    legacy_writer = getattr(legacy_output, 'write_wav', None)
    if legacy_writer is not None:
        # Old librosa (< 0.8): keep the exact original behaviour.
        legacy_writer(path, wav, sr=sr)
        return
    # NOTE(review): unlike the legacy librosa writer, this path does no
    # input validation or channel reordering -- presumably `wav` is a mono
    # float array here; confirm against callers.
    wavfile.write(path, sr, wav)
...@@ -75,6 +76,7 @@ def _stft(y): ...@@ -75,6 +76,7 @@ def _stft(y):
if audio_config.use_lws: if audio_config.use_lws:
return _lws_processor(audio_config).stft(y).T return _lws_processor(audio_config).stft(y).T
else: else:
librosa = try_import('librosa')
return librosa.stft(y=y, return librosa.stft(y=y,
n_fft=audio_config.n_fft, n_fft=audio_config.n_fft,
hop_length=get_hop_size(), hop_length=get_hop_size(),
...@@ -123,6 +125,7 @@ def _linear_to_mel(spectogram): ...@@ -123,6 +125,7 @@ def _linear_to_mel(spectogram):
def _build_mel_basis(): def _build_mel_basis():
assert audio_config.fmax <= audio_config.sample_rate // 2 assert audio_config.fmax <= audio_config.sample_rate // 2
librosa = try_import('librosa')
return librosa.filters.mel(audio_config.sample_rate, return librosa.filters.mel(audio_config.sample_rate,
audio_config.n_fft, audio_config.n_fft,
n_mels=audio_config.num_mels, n_mels=audio_config.num_mels,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册