未验证 提交 530a6a8c 编写于 作者: L LielinJiang 提交者: GitHub

Add train code of stylegan2 (#149)

* add stylegan model
上级 e13e1c18
total_iters: 800000
output_dir: output_dir
model:
name: StyleGAN2Model
generator:
name: StyleGANv2Generator
size: 256
style_dim: 512
n_mlp: 8
discriminator:
name: StyleGANv2Discriminator
size: 256
gan_criterion:
name: GANLoss
gan_mode: logistic
loss_weight: !!float 1
# r1 regularization for discriminator
r1_reg_weight: 10.
# path length regularization for generator
path_batch_shrink: 2.
path_reg_weight: 2.
params:
gen_iters: 4
disc_iters: 16
dataset:
train:
name: SingleDataset
dataroot: data/ffhq/images256x256/
num_workers: 3
batch_size: 3
preprocess:
- name: LoadImageFromFile
key: A
- name: Transforms
input_keys: [A]
pipeline:
- name: RandomHorizontalFlip
- name: Transpose
- name: Normalize
mean: [127.5, 127.5, 127.5]
std: [127.5, 127.5, 127.5]
lr_scheduler:
name: MultiStepDecay
learning_rate: 0.002
milestones: [600000]
gamma: 0.5
optimizer:
optimG:
name: Adam
beta1: 0.0
beta2: 0.792
net_names:
- gen
optimD:
name: Adam
net_names:
- disc
beta1: 0.0
beta2: 0.9317647058823529
log_config:
interval: 50
visiual_interval: 500
snapshot_config:
interval: 5000
...@@ -59,19 +59,21 @@ class Transforms(): ...@@ -59,19 +59,21 @@ class Transforms():
data = tuple(data) data = tuple(data)
for transform in self.transforms: for transform in self.transforms:
data = transform(data) data = transform(data)
if hasattr(transform, 'params') and isinstance( if hasattr(transform, 'params') and isinstance(
transform.params, dict): transform.params, dict):
datas.update(transform.params) datas.update(transform.params)
if len(self.input_keys) > 1:
for i, k in enumerate(self.input_keys):
datas[k] = data[i]
else:
datas[k] = data
if self.output_keys is not None: if self.output_keys is not None:
for i, k in enumerate(self.output_keys): for i, k in enumerate(self.output_keys):
datas[k] = data[i] datas[k] = data[i]
return datas return datas
for i, k in enumerate(self.input_keys):
datas[k] = data[i]
return datas return datas
......
...@@ -27,7 +27,7 @@ class SingleDataset(BaseDataset): ...@@ -27,7 +27,7 @@ class SingleDataset(BaseDataset):
dataroot (str): Directory of dataset. dataroot (str): Directory of dataset.
preprocess (list[dict]): A sequence of data preprocess config. preprocess (list[dict]): A sequence of data preprocess config.
""" """
super(SingleDataset).__init__(self, preprocess) super(SingleDataset, self).__init__(preprocess)
self.dataroot = dataroot self.dataroot = dataroot
self.data_infos = self.prepare_data_infos() self.data_infos = self.prepare_data_infos()
......
...@@ -123,6 +123,8 @@ class Trainer: ...@@ -123,6 +123,8 @@ class Trainer:
self.batch_id = 0 self.batch_id = 0
self.global_steps = 0 self.global_steps = 0
self.weight_interval = cfg.snapshot_config.interval self.weight_interval = cfg.snapshot_config.interval
if self.by_epoch:
self.weight_interval *= self.iters_per_epoch
self.log_interval = cfg.log_config.interval self.log_interval = cfg.log_config.interval
self.visual_interval = cfg.log_config.visiual_interval self.visual_interval = cfg.log_config.visiual_interval
if self.by_epoch: if self.by_epoch:
...@@ -143,6 +145,17 @@ class Trainer: ...@@ -143,6 +145,17 @@ class Trainer:
for net_name, net in self.model.nets.items(): for net_name, net in self.model.nets.items():
self.model.nets[net_name] = paddle.DataParallel(net, strategy) self.model.nets[net_name] = paddle.DataParallel(net, strategy)
def learning_rate_scheduler_step(self):
if isinstance(self.model.lr_scheduler, dict):
for lr_scheduler in self.model.lr_scheduler.values():
lr_scheduler.step()
elif isinstance(self.model.lr_scheduler,
paddle.optimizer.lr.LRScheduler):
self.model.lr_scheduler.step()
else:
raise ValueError(
'lr schedulter must be a dict or an instance of LRScheduler')
def train(self): def train(self):
reader_cost_averager = TimeAverager() reader_cost_averager = TimeAverager()
batch_cost_averager = TimeAverager() batch_cost_averager = TimeAverager()
...@@ -179,7 +192,7 @@ class Trainer: ...@@ -179,7 +192,7 @@ class Trainer:
if self.current_iter % self.visual_interval == 0: if self.current_iter % self.visual_interval == 0:
self.visual('visual_train') self.visual('visual_train')
self.model.lr_scheduler.step() self.learning_rate_scheduler_step()
if self.validate_interval > -1 and self.current_iter % self.validate_interval == 0: if self.validate_interval > -1 and self.current_iter % self.validate_interval == 0:
self.test() self.test()
......
...@@ -22,5 +22,6 @@ from .esrgan_model import ESRGAN ...@@ -22,5 +22,6 @@ from .esrgan_model import ESRGAN
from .ugatit_model import UGATITModel from .ugatit_model import UGATITModel
from .dc_gan_model import DCGANModel from .dc_gan_model import DCGANModel
from .animeganv2_model import AnimeGANV2Model, AnimeGANV2PreTrainModel from .animeganv2_model import AnimeGANV2Model, AnimeGANV2PreTrainModel
from .styleganv2_model import StyleGAN2Model
from .wav2lip_model import Wav2LipModel from .wav2lip_model import Wav2LipModel
from .wav2lip_hq_model import Wav2LipModelHq from .wav2lip_hq_model import Wav2LipModelHq
...@@ -35,22 +35,22 @@ class ConvLayer(nn.Sequential): ...@@ -35,22 +35,22 @@ class ConvLayer(nn.Sequential):
activate=True, activate=True,
): ):
layers = [] layers = []
if downsample: if downsample:
factor = 2 factor = 2
p = (len(blur_kernel) - factor) + (kernel_size - 1) p = (len(blur_kernel) - factor) + (kernel_size - 1)
pad0 = (p + 1) // 2 pad0 = (p + 1) // 2
pad1 = p // 2 pad1 = p // 2
layers.append(Upfirdn2dBlur(blur_kernel, pad=(pad0, pad1))) layers.append(Upfirdn2dBlur(blur_kernel, pad=(pad0, pad1)))
stride = 2 stride = 2
self.padding = 0 self.padding = 0
else: else:
stride = 1 stride = 1
self.padding = kernel_size // 2 self.padding = kernel_size // 2
layers.append( layers.append(
EqualConv2D( EqualConv2D(
in_channel, in_channel,
...@@ -59,41 +59,58 @@ class ConvLayer(nn.Sequential): ...@@ -59,41 +59,58 @@ class ConvLayer(nn.Sequential):
padding=self.padding, padding=self.padding,
stride=stride, stride=stride,
bias=bias and not activate, bias=bias and not activate,
) ))
)
if activate: if activate:
layers.append(FusedLeakyReLU(out_channel, bias=bias)) layers.append(FusedLeakyReLU(out_channel, bias=bias))
super().__init__(*layers) super().__init__(*layers)
class ResBlock(nn.Layer): class ResBlock(nn.Layer):
def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]): def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
super().__init__() super().__init__()
self.conv1 = ConvLayer(in_channel, in_channel, 3) self.conv1 = ConvLayer(in_channel, in_channel, 3)
self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True) self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
self.skip = ConvLayer( self.skip = ConvLayer(in_channel,
in_channel, out_channel, 1, downsample=True, activate=False, bias=False out_channel,
) 1,
downsample=True,
activate=False,
bias=False)
def forward(self, input): def forward(self, input):
out = self.conv1(input) out = self.conv1(input)
out = self.conv2(out) out = self.conv2(out)
skip = self.skip(input) skip = self.skip(input)
out = (out + skip) / math.sqrt(2) out = (out + skip) / math.sqrt(2)
return out return out
# temporally solve pow double grad problem
def var(x, axis=None, unbiased=True, keepdim=False, name=None):
u = paddle.mean(x, axis, True, name)
out = paddle.sum((x - u) * (x - u), axis, keepdim=keepdim, name=name)
n = paddle.cast(paddle.numel(x), x.dtype) \
/ paddle.cast(paddle.numel(out), x.dtype)
if unbiased:
one_const = paddle.ones([1], x.dtype)
n = paddle.where(n > one_const, n - 1., one_const)
out /= n
return out
@DISCRIMINATORS.register() @DISCRIMINATORS.register()
class StyleGANv2Discriminator(nn.Layer): class StyleGANv2Discriminator(nn.Layer):
def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]): def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
super().__init__() super().__init__()
channels = { channels = {
4: 512, 4: 512,
8: 512, 8: 512,
...@@ -105,47 +122,48 @@ class StyleGANv2Discriminator(nn.Layer): ...@@ -105,47 +122,48 @@ class StyleGANv2Discriminator(nn.Layer):
512: 32 * channel_multiplier, 512: 32 * channel_multiplier,
1024: 16 * channel_multiplier, 1024: 16 * channel_multiplier,
} }
convs = [ConvLayer(3, channels[size], 1)] convs = [ConvLayer(3, channels[size], 1)]
log_size = int(math.log(size, 2)) log_size = int(math.log(size, 2))
in_channel = channels[size] in_channel = channels[size]
for i in range(log_size, 2, -1): for i in range(log_size, 2, -1):
out_channel = channels[2 ** (i - 1)] out_channel = channels[2**(i - 1)]
convs.append(ResBlock(in_channel, out_channel, blur_kernel)) convs.append(ResBlock(in_channel, out_channel, blur_kernel))
in_channel = out_channel in_channel = out_channel
self.convs = nn.Sequential(*convs) self.convs = nn.Sequential(*convs)
self.stddev_group = 4 self.stddev_group = 4
self.stddev_feat = 1 self.stddev_feat = 1
self.final_conv = ConvLayer(in_channel + 1, channels[4], 3) self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
self.final_linear = nn.Sequential( self.final_linear = nn.Sequential(
EqualLinear(channels[4] * 4 * 4, channels[4], activation="fused_lrelu"), EqualLinear(channels[4] * 4 * 4,
channels[4],
activation="fused_lrelu"),
EqualLinear(channels[4], 1), EqualLinear(channels[4], 1),
) )
def forward(self, input): def forward(self, input):
out = self.convs(input) out = self.convs(input)
batch, channel, height, width = out.shape batch, channel, height, width = out.shape
group = min(batch, self.stddev_group) group = min(batch, self.stddev_group)
stddev = out.reshape(( stddev = out.reshape((group, -1, self.stddev_feat,
group, -1, self.stddev_feat, channel // self.stddev_feat, height, width channel // self.stddev_feat, height, width))
)) stddev = paddle.sqrt(var(stddev, 0, unbiased=False) + 1e-8)
stddev = paddle.sqrt(stddev.var(0, unbiased=False) + 1e-8)
stddev = stddev.mean([2, 3, 4], keepdim=True).squeeze(2) stddev = stddev.mean([2, 3, 4], keepdim=True).squeeze(2)
stddev = stddev.tile((group, 1, height, width)) stddev = stddev.tile((group, 1, height, width))
out = paddle.concat([out, stddev], 1) out = paddle.concat([out, stddev], 1)
out = self.final_conv(out) out = self.final_conv(out)
out = out.reshape((batch, -1)) out = out.reshape((batch, -1))
out = self.final_linear(out) out = self.final_linear(out)
return out return out
...@@ -27,11 +27,12 @@ from ...modules.upfirdn2d import Upfirdn2dUpsample, Upfirdn2dBlur ...@@ -27,11 +27,12 @@ from ...modules.upfirdn2d import Upfirdn2dUpsample, Upfirdn2dBlur
class PixelNorm(nn.Layer): class PixelNorm(nn.Layer):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
def forward(self, input): def forward(self, input):
return input * paddle.rsqrt(paddle.mean(input ** 2, 1, keepdim=True) + 1e-8) return input * paddle.rsqrt(
paddle.mean(input * input, 1, keepdim=True) + 1e-8)
class ModulatedConv2D(nn.Layer): class ModulatedConv2D(nn.Layer):
def __init__( def __init__(
self, self,
...@@ -45,75 +46,78 @@ class ModulatedConv2D(nn.Layer): ...@@ -45,75 +46,78 @@ class ModulatedConv2D(nn.Layer):
blur_kernel=[1, 3, 3, 1], blur_kernel=[1, 3, 3, 1],
): ):
super().__init__() super().__init__()
self.eps = 1e-8 self.eps = 1e-8
self.kernel_size = kernel_size self.kernel_size = kernel_size
self.in_channel = in_channel self.in_channel = in_channel
self.out_channel = out_channel self.out_channel = out_channel
self.upsample = upsample self.upsample = upsample
self.downsample = downsample self.downsample = downsample
if upsample: if upsample:
factor = 2 factor = 2
p = (len(blur_kernel) - factor) - (kernel_size - 1) p = (len(blur_kernel) - factor) - (kernel_size - 1)
pad0 = (p + 1) // 2 + factor - 1 pad0 = (p + 1) // 2 + factor - 1
pad1 = p // 2 + 1 pad1 = p // 2 + 1
self.blur = Upfirdn2dBlur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor) self.blur = Upfirdn2dBlur(blur_kernel,
pad=(pad0, pad1),
upsample_factor=factor)
if downsample: if downsample:
factor = 2 factor = 2
p = (len(blur_kernel) - factor) + (kernel_size - 1) p = (len(blur_kernel) - factor) + (kernel_size - 1)
pad0 = (p + 1) // 2 pad0 = (p + 1) // 2
pad1 = p // 2 pad1 = p // 2
self.blur = Upfirdn2dBlur(blur_kernel, pad=(pad0, pad1)) self.blur = Upfirdn2dBlur(blur_kernel, pad=(pad0, pad1))
fan_in = in_channel * kernel_size ** 2 fan_in = in_channel * (kernel_size * kernel_size)
self.scale = 1 / math.sqrt(fan_in) self.scale = 1 / math.sqrt(fan_in)
self.padding = kernel_size // 2 self.padding = kernel_size // 2
self.weight = self.create_parameter( self.weight = self.create_parameter(
(1, out_channel, in_channel, kernel_size, kernel_size), default_initializer=nn.initializer.Normal() (1, out_channel, in_channel, kernel_size, kernel_size),
) default_initializer=nn.initializer.Normal())
self.modulation = EqualLinear(style_dim, in_channel, bias_init=1) self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
self.demodulate = demodulate self.demodulate = demodulate
def __repr__(self): def __repr__(self):
return ( return (
f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, " f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, "
f"upsample={self.upsample}, downsample={self.downsample})" f"upsample={self.upsample}, downsample={self.downsample})")
)
def forward(self, input, style): def forward(self, input, style):
batch, in_channel, height, width = input.shape batch, in_channel, height, width = input.shape
style = self.modulation(style).reshape((batch, 1, in_channel, 1, 1)) style = self.modulation(style).reshape((batch, 1, in_channel, 1, 1))
weight = self.scale * self.weight * style weight = self.scale * self.weight * style
if self.demodulate: if self.demodulate:
demod = paddle.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8) demod = paddle.rsqrt((weight * weight).sum([2, 3, 4]) + 1e-8)
weight = weight * demod.reshape((batch, self.out_channel, 1, 1, 1)) weight = weight * demod.reshape((batch, self.out_channel, 1, 1, 1))
weight = weight.reshape(( weight = weight.reshape((batch * self.out_channel, in_channel,
batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size self.kernel_size, self.kernel_size))
))
if self.upsample: if self.upsample:
input = input.reshape((1, batch * in_channel, height, width)) input = input.reshape((1, batch * in_channel, height, width))
weight = weight.reshape(( weight = weight.reshape((batch, self.out_channel, in_channel,
batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size self.kernel_size, self.kernel_size))
)) weight = weight.transpose((0, 2, 1, 3, 4)).reshape(
weight = weight.transpose((0, 2, 1, 3, 4)).reshape(( (batch * in_channel, self.out_channel, self.kernel_size,
batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size self.kernel_size))
)) out = F.conv2d_transpose(input,
out = F.conv2d_transpose(input, weight, padding=0, stride=2, groups=batch) weight,
padding=0,
stride=2,
groups=batch)
_, _, height, width = out.shape _, _, height, width = out.shape
out = out.reshape((batch, self.out_channel, height, width)) out = out.reshape((batch, self.out_channel, height, width))
out = self.blur(out) out = self.blur(out)
elif self.downsample: elif self.downsample:
input = self.blur(input) input = self.blur(input)
_, _, height, width = input.shape _, _, height, width = input.shape
...@@ -121,43 +125,46 @@ class ModulatedConv2D(nn.Layer): ...@@ -121,43 +125,46 @@ class ModulatedConv2D(nn.Layer):
out = F.conv2d(input, weight, padding=0, stride=2, groups=batch) out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
_, _, height, width = out.shape _, _, height, width = out.shape
out = out.reshape((batch, self.out_channel, height, width)) out = out.reshape((batch, self.out_channel, height, width))
else: else:
input = input.reshape((1, batch * in_channel, height, width)) input = input.reshape((1, batch * in_channel, height, width))
out = F.conv2d(input, weight, padding=self.padding, groups=batch) out = F.conv2d(input, weight, padding=self.padding, groups=batch)
_, _, height, width = out.shape _, _, height, width = out.shape
out = out.reshape((batch, self.out_channel, height, width)) out = out.reshape((batch, self.out_channel, height, width))
return out return out
class NoiseInjection(nn.Layer): class NoiseInjection(nn.Layer):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.weight = self.create_parameter((1,), default_initializer=nn.initializer.Constant(0.0)) self.weight = self.create_parameter(
(1, ), default_initializer=nn.initializer.Constant(0.0))
def forward(self, image, noise=None): def forward(self, image, noise=None):
if noise is None: if noise is None:
batch, _, height, width = image.shape batch, _, height, width = image.shape
noise = paddle.randn((batch, 1, height, width)) noise = paddle.randn((batch, 1, height, width))
return image + self.weight * noise return image + self.weight * noise
class ConstantInput(nn.Layer): class ConstantInput(nn.Layer):
def __init__(self, channel, size=4): def __init__(self, channel, size=4):
super().__init__() super().__init__()
self.input = self.create_parameter((1, channel, size, size), default_initializer=nn.initializer.Normal()) self.input = self.create_parameter(
(1, channel, size, size),
default_initializer=nn.initializer.Normal())
def forward(self, input): def forward(self, input):
batch = input.shape[0] batch = input.shape[0]
out = self.input.tile((batch, 1, 1, 1)) out = self.input.tile((batch, 1, 1, 1))
return out return out
class StyledConv(nn.Layer): class StyledConv(nn.Layer):
def __init__( def __init__(
self, self,
...@@ -170,7 +177,7 @@ class StyledConv(nn.Layer): ...@@ -170,7 +177,7 @@ class StyledConv(nn.Layer):
demodulate=True, demodulate=True,
): ):
super().__init__() super().__init__()
self.conv = ModulatedConv2D( self.conv = ModulatedConv2D(
in_channel, in_channel,
out_channel, out_channel,
...@@ -180,40 +187,49 @@ class StyledConv(nn.Layer): ...@@ -180,40 +187,49 @@ class StyledConv(nn.Layer):
blur_kernel=blur_kernel, blur_kernel=blur_kernel,
demodulate=demodulate, demodulate=demodulate,
) )
self.noise = NoiseInjection() self.noise = NoiseInjection()
self.activate = FusedLeakyReLU(out_channel) self.activate = FusedLeakyReLU(out_channel)
def forward(self, input, style, noise=None): def forward(self, input, style, noise=None):
out = self.conv(input, style) out = self.conv(input, style)
out = self.noise(out, noise=noise) out = self.noise(out, noise=noise)
out = self.activate(out) out = self.activate(out)
return out return out
class ToRGB(nn.Layer): class ToRGB(nn.Layer):
def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]): def __init__(self,
in_channel,
style_dim,
upsample=True,
blur_kernel=[1, 3, 3, 1]):
super().__init__() super().__init__()
if upsample: if upsample:
self.upsample = Upfirdn2dUpsample(blur_kernel) self.upsample = Upfirdn2dUpsample(blur_kernel)
self.conv = ModulatedConv2D(in_channel, 3, 1, style_dim, demodulate=False) self.conv = ModulatedConv2D(in_channel,
self.bias = self.create_parameter((1, 3, 1, 1), nn.initializer.Constant(0.0)) 3,
1,
style_dim,
demodulate=False)
self.bias = self.create_parameter((1, 3, 1, 1),
nn.initializer.Constant(0.0))
def forward(self, input, style, skip=None): def forward(self, input, style, skip=None):
out = self.conv(input, style) out = self.conv(input, style)
out = out + self.bias out = out + self.bias
if skip is not None: if skip is not None:
skip = self.upsample(skip) skip = self.upsample(skip)
out = out + skip out = out + skip
return out return out
@GENERATORS.register() @GENERATORS.register()
class StyleGANv2Generator(nn.Layer): class StyleGANv2Generator(nn.Layer):
def __init__( def __init__(
...@@ -226,22 +242,22 @@ class StyleGANv2Generator(nn.Layer): ...@@ -226,22 +242,22 @@ class StyleGANv2Generator(nn.Layer):
lr_mlp=0.01, lr_mlp=0.01,
): ):
super().__init__() super().__init__()
self.size = size self.size = size
self.style_dim = style_dim self.style_dim = style_dim
layers = [PixelNorm()] layers = [PixelNorm()]
for i in range(n_mlp): for i in range(n_mlp):
layers.append( layers.append(
EqualLinear( EqualLinear(style_dim,
style_dim, style_dim, lr_mul=lr_mlp, activation="fused_lrelu" style_dim,
) lr_mul=lr_mlp,
) activation="fused_lrelu"))
self.style = nn.Sequential(*layers) self.style = nn.Sequential(*layers)
self.channels = { self.channels = {
4: 512, 4: 512,
8: 512, 8: 512,
...@@ -253,31 +269,34 @@ class StyleGANv2Generator(nn.Layer): ...@@ -253,31 +269,34 @@ class StyleGANv2Generator(nn.Layer):
512: 32 * channel_multiplier, 512: 32 * channel_multiplier,
1024: 16 * channel_multiplier, 1024: 16 * channel_multiplier,
} }
self.input = ConstantInput(self.channels[4]) self.input = ConstantInput(self.channels[4])
self.conv1 = StyledConv( self.conv1 = StyledConv(self.channels[4],
self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel self.channels[4],
) 3,
style_dim,
blur_kernel=blur_kernel)
self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False) self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)
self.log_size = int(math.log(size, 2)) self.log_size = int(math.log(size, 2))
self.num_layers = (self.log_size - 2) * 2 + 1 self.num_layers = (self.log_size - 2) * 2 + 1
self.convs = nn.LayerList() self.convs = nn.LayerList()
self.upsamples = nn.LayerList() self.upsamples = nn.LayerList()
self.to_rgbs = nn.LayerList() self.to_rgbs = nn.LayerList()
self.noises = nn.Layer() self.noises = nn.Layer()
in_channel = self.channels[4] in_channel = self.channels[4]
for layer_idx in range(self.num_layers): for layer_idx in range(self.num_layers):
res = (layer_idx + 5) // 2 res = (layer_idx + 5) // 2
shape = [1, 1, 2 ** res, 2 ** res] shape = [1, 1, 2**res, 2**res]
self.noises.register_buffer(f"noise_{layer_idx}", paddle.randn(shape)) self.noises.register_buffer(f"noise_{layer_idx}",
paddle.randn(shape))
for i in range(3, self.log_size + 1): for i in range(3, self.log_size + 1):
out_channel = self.channels[2 ** i] out_channel = self.channels[2**i]
self.convs.append( self.convs.append(
StyledConv( StyledConv(
in_channel, in_channel,
...@@ -286,41 +305,39 @@ class StyleGANv2Generator(nn.Layer): ...@@ -286,41 +305,39 @@ class StyleGANv2Generator(nn.Layer):
style_dim, style_dim,
upsample=True, upsample=True,
blur_kernel=blur_kernel, blur_kernel=blur_kernel,
) ))
)
self.convs.append( self.convs.append(
StyledConv( StyledConv(out_channel,
out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel out_channel,
) 3,
) style_dim,
blur_kernel=blur_kernel))
self.to_rgbs.append(ToRGB(out_channel, style_dim)) self.to_rgbs.append(ToRGB(out_channel, style_dim))
in_channel = out_channel in_channel = out_channel
self.n_latent = self.log_size * 2 - 2 self.n_latent = self.log_size * 2 - 2
def make_noise(self): def make_noise(self):
noises = [paddle.randn((1, 1, 2 ** 2, 2 ** 2))] noises = [paddle.randn((1, 1, 2**2, 2**2))]
for i in range(3, self.log_size + 1): for i in range(3, self.log_size + 1):
for _ in range(2): for _ in range(2):
noises.append(paddle.randn((1, 1, 2 ** i, 2 ** i))) noises.append(paddle.randn((1, 1, 2**i, 2**i)))
return noises return noises
def mean_latent(self, n_latent): def mean_latent(self, n_latent):
latent_in = paddle.randn(( latent_in = paddle.randn((n_latent, self.style_dim))
n_latent, self.style_dim
))
latent = self.style(latent_in).mean(0, keepdim=True) latent = self.style(latent_in).mean(0, keepdim=True)
return latent return latent
def get_latent(self, input): def get_latent(self, input):
return self.style(input) return self.style(input)
def forward( def forward(
self, self,
styles, styles,
...@@ -334,62 +351,65 @@ class StyleGANv2Generator(nn.Layer): ...@@ -334,62 +351,65 @@ class StyleGANv2Generator(nn.Layer):
): ):
if not input_is_latent: if not input_is_latent:
styles = [self.style(s) for s in styles] styles = [self.style(s) for s in styles]
if noise is None: if noise is None:
if randomize_noise: if randomize_noise:
noise = [None] * self.num_layers noise = [None] * self.num_layers
else: else:
noise = [ noise = [
getattr(self.noises, f"noise_{i}") for i in range(self.num_layers) getattr(self.noises, f"noise_{i}")
for i in range(self.num_layers)
] ]
if truncation < 1: if truncation < 1:
style_t = [] style_t = []
for style in styles: for style in styles:
style_t.append( style_t.append(truncation_latent + truncation *
truncation_latent + truncation * (style - truncation_latent) (style - truncation_latent))
)
styles = style_t styles = style_t
if len(styles) < 2: if len(styles) < 2:
inject_index = self.n_latent inject_index = self.n_latent
if styles[0].ndim < 3: if styles[0].ndim < 3:
latent = styles[0].unsqueeze(1).tile((1, inject_index, 1)) latent = styles[0].unsqueeze(1).tile((1, inject_index, 1))
else: else:
latent = styles[0] latent = styles[0]
else: else:
if inject_index is None: if inject_index is None:
inject_index = random.randint(1, self.n_latent - 1) inject_index = random.randint(1, self.n_latent - 1)
latent = styles[0].unsqueeze(1).tile((1, inject_index, 1)) latent = styles[0].unsqueeze(1).tile((1, inject_index, 1))
latent2 = styles[1].unsqueeze(1).tile((1, self.n_latent - inject_index, 1)) latent2 = styles[1].unsqueeze(1).tile(
(1, self.n_latent - inject_index, 1))
latent = paddle.concat([latent, latent2], 1) latent = paddle.concat([latent, latent2], 1)
out = self.input(latent) out = self.input(latent)
out = self.conv1(out, latent[:, 0], noise=noise[0]) out = self.conv1(out, latent[:, 0], noise=noise[0])
skip = self.to_rgb1(out, latent[:, 1]) skip = self.to_rgb1(out, latent[:, 1])
i = 1 i = 1
for conv1, conv2, noise1, noise2, to_rgb in zip( for conv1, conv2, noise1, noise2, to_rgb in zip(self.convs[::2],
self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs self.convs[1::2],
): noise[1::2],
noise[2::2],
self.to_rgbs):
out = conv1(out, latent[:, i], noise=noise1) out = conv1(out, latent[:, i], noise=noise1)
out = conv2(out, latent[:, i + 1], noise=noise2) out = conv2(out, latent[:, i + 1], noise=noise2)
skip = to_rgb(out, latent[:, i + 2], skip) skip = to_rgb(out, latent[:, i + 2], skip)
i += 2 i += 2
image = skip image = skip
if return_latents: if return_latents:
return image, latent return image, latent
else: else:
return image, None return image, None
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import random
import paddle
import paddle.nn as nn
from .base_model import BaseModel
from .builder import MODELS
from .criterions import build_criterion
from .generators.builder import build_generator
from .discriminators.builder import build_discriminator
from ..solver import build_lr_scheduler, build_optimizer
def r1_penalty(real_pred, real_img):
"""
R1 regularization for discriminator. The core idea is to
penalize the gradient on real data alone: when the
generator distribution produces the true data distribution
and the discriminator is equal to 0 on the data manifold, the
gradient penalty ensures that the discriminator cannot create
a non-zero gradient orthogonal to the data manifold without
suffering a loss in the GAN game.
Ref:
Eq. 9 in Which training methods for GANs do actually converge.
"""
grad_real = paddle.grad(outputs=real_pred.sum(),
inputs=real_img,
create_graph=True)[0]
grad_penalty = (grad_real * grad_real).reshape([grad_real.shape[0],
-1]).sum(1).mean()
return grad_penalty
def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
noise = paddle.randn(fake_img.shape) / math.sqrt(
fake_img.shape[2] * fake_img.shape[3])
grad = paddle.grad(outputs=(fake_img * noise).sum(),
inputs=latents,
create_graph=True)[0]
path_lengths = paddle.sqrt((grad * grad).sum(2).mean(1))
path_mean = mean_path_length + decay * (path_lengths.mean() -
mean_path_length)
path_penalty = ((path_lengths - path_mean) *
(path_lengths - path_mean)).mean()
return path_penalty, path_lengths.detach().mean(), path_mean.detach()
@MODELS.register()
class StyleGAN2Model(BaseModel):
"""
This class implements the StyleGANV2 model, for learning image-to-image translation without paired data.
StyleGAN2 paper: https://arxiv.org/pdf/1912.04958.pdf
"""
def __init__(self,
generator,
discriminator=None,
gan_criterion=None,
num_style_feat=512,
mixing_prob=0.9,
r1_reg_weight=10.,
path_reg_weight=2.,
path_batch_shrink=2.,
params=None):
"""Initialize the CycleGAN class.
Args:
generator (dict): config of generator.
discriminator (dict): config of discriminator.
gan_criterion (dict): config of gan criterion.
"""
super(StyleGAN2Model, self).__init__(params)
self.gen_iters = 4 if self.params is None else self.params.get(
'gen_iters', 4)
self.disc_iters = 16 if self.params is None else self.params.get(
'disc_iters', 16)
self.disc_start_iters = (0 if self.params is None else self.params.get(
'disc_start_iters', 0))
self.visual_iters = (500 if self.params is None else self.params.get(
'visual_iters', 500))
self.mixing_prob = mixing_prob
self.num_style_feat = num_style_feat
self.r1_reg_weight = r1_reg_weight
self.path_reg_weight = path_reg_weight
self.path_batch_shrink = path_batch_shrink
self.mean_path_length = 0
self.nets['gen'] = build_generator(generator)
# define discriminators
if discriminator:
self.nets['disc'] = build_discriminator(discriminator)
self.nets['gen_ema'] = build_generator(generator)
self.model_ema(0)
self.nets['gen'].train()
self.nets['gen_ema'].eval()
self.nets['disc'].train()
self.current_iter = 1
# define loss functions
if gan_criterion:
self.gan_criterion = build_criterion(gan_criterion)
def setup_lr_schedulers(self, cfg):
self.lr_scheduler = dict()
gen_cfg = cfg.copy()
net_g_reg_ratio = self.gen_iters / (self.gen_iters + 1)
gen_cfg['learning_rate'] = cfg['learning_rate'] * net_g_reg_ratio
self.lr_scheduler['gen'] = build_lr_scheduler(gen_cfg)
disc_cfg = cfg.copy()
net_d_reg_ratio = self.disc_iters / (self.disc_iters + 1)
disc_cfg['learning_rate'] = cfg['learning_rate'] * net_d_reg_ratio
self.lr_scheduler['disc'] = build_lr_scheduler(disc_cfg)
return self.lr_scheduler
def setup_optimizers(self, lr, cfg):
for opt_name, opt_cfg in cfg.items():
if opt_name == 'optimG':
_lr = lr['gen']
elif opt_name == 'optimD':
_lr = lr['disc']
else:
raise ValueError("opt name must be in ['optimG', optimD]")
cfg_ = opt_cfg.copy()
net_names = cfg_.pop('net_names')
parameters = []
for net_name in net_names:
parameters += self.nets[net_name].parameters()
self.optimizers[opt_name] = build_optimizer(cfg_, _lr, parameters)
return self.optimizers
def get_bare_model(self, net):
"""Get bare model, especially under wrapping with DataParallel.
"""
if isinstance(net, (paddle.DataParallel)):
net = net._layers
return net
def model_ema(self, decay=0.999):
net_g = self.get_bare_model(self.nets['gen'])
net_g_params = dict(net_g.named_parameters())
neg_g_ema = self.get_bare_model(self.nets['gen_ema'])
net_g_ema_params = dict(neg_g_ema.named_parameters())
for k in net_g_ema_params.keys():
net_g_ema_params[k].set_value(net_g_ema_params[k] * (decay) +
(net_g_params[k] * (1 - decay)))
def setup_input(self, input):
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
Args:
input (dict): include the data itself and its metadata information.
"""
self.real_img = paddle.fluid.dygraph.to_variable(input['A'])
def forward(self):
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
pass
def make_noise(self, batch, num_noise):
if num_noise == 1:
noises = paddle.randn([batch, self.num_style_feat])
else:
noises = []
for _ in range(num_noise):
noises.append(paddle.randn([batch, self.num_style_feat]))
return noises
def mixing_noise(self, batch, prob):
if random.random() < prob:
return self.make_noise(batch, 2)
else:
return [self.make_noise(batch, 1)]
def train_iter(self, optimizers=None):
current_iter = self.current_iter
self.set_requires_grad(self.nets['disc'], True)
optimizers['optimD'].clear_grad()
batch = self.real_img.shape[0]
noise = self.mixing_noise(batch, self.mixing_prob)
fake_img, _ = self.nets['gen'](noise)
self.visual_items['real_img'] = self.real_img
self.visual_items['fake_img'] = fake_img
fake_pred = self.nets['disc'](fake_img.detach())
real_pred = self.nets['disc'](self.real_img)
# wgan loss with softplus (logistic loss) for discriminator
l_d_total = 0.
l_d = self.gan_criterion(real_pred, True,
is_disc=True) + self.gan_criterion(
fake_pred, False, is_disc=True)
self.losses['l_d'] = l_d
# In wgan, real_score should be positive and fake_score should be
# negative
self.losses['real_score'] = real_pred.detach().mean()
self.losses['fake_score'] = fake_pred.detach().mean()
l_d_total += l_d
if current_iter % self.disc_iters == 0:
self.real_img.stop_gradient = False
real_pred = self.nets['disc'](self.real_img)
l_d_r1 = r1_penalty(real_pred, self.real_img)
l_d_r1 = (self.r1_reg_weight / 2 * l_d_r1 * self.disc_iters +
0 * real_pred[0])
self.losses['l_d_r1'] = l_d_r1.detach().mean()
l_d_total += l_d_r1
l_d_total.backward()
optimizers['optimD'].step()
self.set_requires_grad(self.nets['disc'], False)
optimizers['optimG'].clear_grad()
noise = self.mixing_noise(batch, self.mixing_prob)
fake_img, _ = self.nets['gen'](noise)
fake_pred = self.nets['disc'](fake_img)
# wgan loss with softplus (non-saturating loss) for generator
l_g_total = 0.
l_g = self.gan_criterion(fake_pred, True, is_disc=False)
self.losses['l_g'] = l_g
l_g_total += l_g
if current_iter % self.gen_iters == 0:
path_batch_size = max(1, batch // self.path_batch_shrink)
noise = self.mixing_noise(path_batch_size, self.mixing_prob)
fake_img, latents = self.nets['gen'](noise, return_latents=True)
l_g_path, path_lengths, self.mean_path_length = g_path_regularize(
fake_img, latents, self.mean_path_length)
l_g_path = (self.path_reg_weight * self.gen_iters * l_g_path +
0 * fake_img[0, 0, 0, 0])
l_g_total += l_g_path
self.losses['l_g_path'] = l_g_path.detach().mean()
self.losses['path_length'] = path_lengths
l_g_total.backward()
optimizers['optimG'].step()
# EMA
self.model_ema(decay=0.5**(32 / (10 * 1000)))
if self.current_iter % self.visual_iters:
sample_z = [self.make_noise(1, 1)]
sample, _ = self.nets['gen_ema'](sample_z)
self.visual_items['fake_img_ema'] = sample
self.current_iter += 1
...@@ -24,25 +24,30 @@ class EqualConv2D(nn.Layer): ...@@ -24,25 +24,30 @@ class EqualConv2D(nn.Layer):
"""This convolutional layer class stabilizes the learning rate changes of its parameters. """This convolutional layer class stabilizes the learning rate changes of its parameters.
Equalizing learning rate keeps the weights in the network at a similar scale during training. Equalizing learning rate keeps the weights in the network at a similar scale during training.
""" """
def __init__( def __init__(self,
self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True in_channel,
): out_channel,
kernel_size,
stride=1,
padding=0,
bias=True):
super().__init__() super().__init__()
self.weight = self.create_parameter( self.weight = self.create_parameter(
(out_channel, in_channel, kernel_size, kernel_size), default_initializer=nn.initializer.Normal() (out_channel, in_channel, kernel_size, kernel_size),
) default_initializer=nn.initializer.Normal())
self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2) self.scale = 1 / math.sqrt(in_channel * (kernel_size * kernel_size))
self.stride = stride self.stride = stride
self.padding = padding self.padding = padding
if bias: if bias:
self.bias = self.create_parameter((out_channel,), nn.initializer.Constant(0.0)) self.bias = self.create_parameter((out_channel, ),
nn.initializer.Constant(0.0))
else: else:
self.bias = None self.bias = None
def forward(self, input): def forward(self, input):
out = F.conv2d( out = F.conv2d(
input, input,
...@@ -51,51 +56,57 @@ class EqualConv2D(nn.Layer): ...@@ -51,51 +56,57 @@ class EqualConv2D(nn.Layer):
stride=self.stride, stride=self.stride,
padding=self.padding, padding=self.padding,
) )
return out return out
def __repr__(self): def __repr__(self):
return ( return (
f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]}," f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},"
f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})" f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})"
) )
class EqualLinear(nn.Layer): class EqualLinear(nn.Layer):
"""This linear layer class stabilizes the learning rate changes of its parameters. """This linear layer class stabilizes the learning rate changes of its parameters.
Equalizing learning rate keeps the weights in the network at a similar scale during training. Equalizing learning rate keeps the weights in the network at a similar scale during training.
""" """
def __init__( def __init__(self,
self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None in_dim,
): out_dim,
bias=True,
bias_init=0,
lr_mul=1,
activation=None):
super().__init__() super().__init__()
self.weight = self.create_parameter((in_dim, out_dim), default_initializer=nn.initializer.Normal()) self.weight = self.create_parameter(
self.weight[:] = (self.weight / lr_mul).detach() (in_dim, out_dim), default_initializer=nn.initializer.Normal())
self.weight.set_value((self.weight / lr_mul))
if bias: if bias:
self.bias = self.create_parameter((out_dim,), nn.initializer.Constant(bias_init)) self.bias = self.create_parameter(
(out_dim, ), nn.initializer.Constant(bias_init))
else: else:
self.bias = None self.bias = None
self.activation = activation self.activation = activation
self.scale = (1 / math.sqrt(in_dim)) * lr_mul self.scale = (1 / math.sqrt(in_dim)) * lr_mul
self.lr_mul = lr_mul self.lr_mul = lr_mul
def forward(self, input): def forward(self, input):
if self.activation: if self.activation:
out = F.linear(input, self.weight * self.scale) out = F.linear(input, self.weight * self.scale)
out = fused_leaky_relu(out, self.bias * self.lr_mul) out = fused_leaky_relu(out, self.bias * self.lr_mul)
else: else:
out = F.linear( out = F.linear(input,
input, self.weight * self.scale, bias=self.bias * self.lr_mul self.weight * self.scale,
) bias=self.bias * self.lr_mul)
return out return out
def __repr__(self): def __repr__(self):
return ( return (
f"{self.__class__.__name__}({self.weight.shape[0]}, {self.weight.shape[1]})" f"{self.__class__.__name__}({self.weight.shape[0]}, {self.weight.shape[1]})"
......
...@@ -15,37 +15,35 @@ ...@@ -15,37 +15,35 @@
import paddle import paddle
import paddle.nn as nn import paddle.nn as nn
import paddle.nn.functional as F import paddle.nn.functional as F
def upfirdn2d_native( def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1,
input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 pad_y0, pad_y1):
):
_, channel, in_h, in_w = input.shape _, channel, in_h, in_w = input.shape
input = input.reshape((-1, in_h, in_w, 1)) input = input.reshape((-1, in_h, in_w, 1))
_, in_h, in_w, minor = input.shape _, in_h, in_w, minor = input.shape
kernel_h, kernel_w = kernel.shape kernel_h, kernel_w = kernel.shape
out = input.reshape((-1, in_h, 1, in_w, 1, minor)) out = input.reshape((-1, in_h, 1, in_w, 1, minor))
out = out.transpose((0,1,3,5,2,4)) out = out.transpose((0, 1, 3, 5, 2, 4))
out = out.reshape((-1,1,1,1)) out = out.reshape((-1, 1, 1, 1))
out = F.pad(out, [0, up_x - 1, 0, up_y - 1]) out = F.pad(out, [0, up_x - 1, 0, up_y - 1])
out = out.reshape((-1, in_h, in_w, minor, up_y, up_x)) out = out.reshape((-1, in_h, in_w, minor, up_y, up_x))
out = out.transpose((0,3,1,4,2,5)) out = out.transpose((0, 3, 1, 4, 2, 5))
out = out.reshape((-1, minor, in_h * up_y, in_w * up_x)) out = out.reshape((-1, minor, in_h * up_y, in_w * up_x))
out = F.pad( out = F.pad(
out, [max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] out, [max(pad_x0, 0),
) max(pad_x1, 0),
out = out[ max(pad_y0, 0),
:,:, max(pad_y1, 0)])
max(-pad_y0, 0) : out.shape[2] - max(-pad_y1, 0), out = out[:, :,
max(-pad_x0, 0) : out.shape[3] - max(-pad_x1, 0), max(-pad_y0, 0):out.shape[2] - max(-pad_y1, 0),
] max(-pad_x0, 0):out.shape[3] - max(-pad_x1, 0), ]
out = out.reshape(( out = out.reshape(
[-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] ([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]))
))
w = paddle.flip(kernel, [0, 1]).reshape((1, 1, kernel_h, kernel_w)) w = paddle.flip(kernel, [0, 1]).reshape((1, 1, kernel_h, kernel_w))
out = F.conv2d(out, w) out = F.conv2d(out, w)
out = out.reshape(( out = out.reshape((
...@@ -56,88 +54,95 @@ def upfirdn2d_native( ...@@ -56,88 +54,95 @@ def upfirdn2d_native(
)) ))
out = out.transpose((0, 2, 3, 1)) out = out.transpose((0, 2, 3, 1))
out = out[:, ::down_y, ::down_x, :] out = out[:, ::down_y, ::down_x, :]
out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
return out.reshape((-1, channel, out_h, out_w)) return out.reshape((-1, channel, out_h, out_w))
def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
out = upfirdn2d_native( out = upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1],
input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1] pad[0], pad[1])
)
return out return out
def make_kernel(k): def make_kernel(k):
k = paddle.to_tensor(k, dtype='float32') k = paddle.to_tensor(k, dtype='float32')
if k.ndim == 1: if k.ndim == 1:
k = k.unsqueeze(0) * k.unsqueeze(1) k = k.unsqueeze(0) * k.unsqueeze(1)
k /= k.sum() k /= k.sum()
return k return k
class Upfirdn2dUpsample(nn.Layer): class Upfirdn2dUpsample(nn.Layer):
def __init__(self, kernel, factor=2): def __init__(self, kernel, factor=2):
super().__init__() super().__init__()
self.factor = factor self.factor = factor
kernel = make_kernel(kernel) * (factor ** 2) kernel = make_kernel(kernel) * (factor * factor)
self.register_buffer("kernel", kernel) self.register_buffer("kernel", kernel)
p = kernel.shape[0] - factor p = kernel.shape[0] - factor
pad0 = (p + 1) // 2 + factor - 1 pad0 = (p + 1) // 2 + factor - 1
pad1 = p // 2 pad1 = p // 2
self.pad = (pad0, pad1) self.pad = (pad0, pad1)
def forward(self, input): def forward(self, input):
out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad) out = upfirdn2d(input,
self.kernel,
up=self.factor,
down=1,
pad=self.pad)
return out return out
class Upfirdn2dDownsample(nn.Layer): class Upfirdn2dDownsample(nn.Layer):
def __init__(self, kernel, factor=2): def __init__(self, kernel, factor=2):
super().__init__() super().__init__()
self.factor = factor self.factor = factor
kernel = make_kernel(kernel) kernel = make_kernel(kernel)
self.register_buffer("kernel", kernel) self.register_buffer("kernel", kernel)
p = kernel.shape[0] - factor p = kernel.shape[0] - factor
pad0 = (p + 1) // 2 pad0 = (p + 1) // 2
pad1 = p // 2 pad1 = p // 2
self.pad = (pad0, pad1) self.pad = (pad0, pad1)
def forward(self, input): def forward(self, input):
out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad) out = upfirdn2d(input,
self.kernel,
up=1,
down=self.factor,
pad=self.pad)
return out return out
class Upfirdn2dBlur(nn.Layer): class Upfirdn2dBlur(nn.Layer):
def __init__(self, kernel, pad, upsample_factor=1): def __init__(self, kernel, pad, upsample_factor=1):
super().__init__() super().__init__()
kernel = make_kernel(kernel) kernel = make_kernel(kernel)
if upsample_factor > 1: if upsample_factor > 1:
kernel = kernel * (upsample_factor ** 2) kernel = kernel * (upsample_factor * upsample_factor)
self.register_buffer("kernel", kernel) self.register_buffer("kernel", kernel, persistable=False)
self.pad = pad self.pad = pad
def forward(self, input): def forward(self, input):
out = upfirdn2d(input, self.kernel, pad=self.pad) out = upfirdn2d(input, self.kernel, pad=self.pad)
return out return out
import librosa
import librosa.filters
import numpy as np import numpy as np
from scipy import signal from scipy import signal
from scipy.io import wavfile from scipy.io import wavfile
from paddle.utils import try_import
from .audio_config import get_audio_config from .audio_config import get_audio_config
audio_config = get_audio_config() audio_config = get_audio_config()
def load_wav(path, sr): def load_wav(path, sr):
librosa = try_import('librosa')
return librosa.core.load(path, sr=sr)[0] return librosa.core.load(path, sr=sr)[0]
...@@ -19,6 +19,7 @@ def save_wav(wav, path, sr): ...@@ -19,6 +19,7 @@ def save_wav(wav, path, sr):
def save_wavenet_wav(wav, path, sr): def save_wavenet_wav(wav, path, sr):
librosa = try_import('librosa')
librosa.output.write_wav(path, wav, sr=sr) librosa.output.write_wav(path, wav, sr=sr)
...@@ -75,6 +76,7 @@ def _stft(y): ...@@ -75,6 +76,7 @@ def _stft(y):
if audio_config.use_lws: if audio_config.use_lws:
return _lws_processor(audio_config).stft(y).T return _lws_processor(audio_config).stft(y).T
else: else:
librosa = try_import('librosa')
return librosa.stft(y=y, return librosa.stft(y=y,
n_fft=audio_config.n_fft, n_fft=audio_config.n_fft,
hop_length=get_hop_size(), hop_length=get_hop_size(),
...@@ -123,6 +125,7 @@ def _linear_to_mel(spectogram): ...@@ -123,6 +125,7 @@ def _linear_to_mel(spectogram):
def _build_mel_basis(): def _build_mel_basis():
assert audio_config.fmax <= audio_config.sample_rate // 2 assert audio_config.fmax <= audio_config.sample_rate // 2
librosa = try_import('librosa')
return librosa.filters.mel(audio_config.sample_rate, return librosa.filters.mel(audio_config.sample_rate,
audio_config.n_fft, audio_config.n_fft,
n_mels=audio_config.num_mels, n_mels=audio_config.num_mels,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册