PaddlePaddle / PaddleGAN
Commit 530a6a8c (unverified)
Authored Jan 22, 2021 by LielinJiang; committed via GitHub on Jan 22, 2021
Add train code of stylegan2 (#149)
* add stylegan model
Parent: e13e1c18
Showing 11 changed files with 711 additions and 285 deletions (+711 / -285):
configs/stylegan_v2_256_ffhq.yaml                         +71   -0
ppgan/datasets/preprocess/transforms.py                   +6    -4
ppgan/datasets/single_dataset.py                          +1    -1
ppgan/engine/trainer.py                                   +14   -1
ppgan/models/__init__.py                                  +1    -0
ppgan/models/discriminators/discriminator_styleganv2.py   +60   -42
ppgan/models/generators/generator_styleganv2.py           +160  -140
ppgan/models/styleganv2_model.py                          +282  -0
ppgan/modules/equalized.py                                +45   -34
ppgan/modules/upfirdn2d.py                                +66   -61
ppgan/utils/audio.py                                      +5    -2

configs/stylegan_v2_256_ffhq.yaml
configs/stylegan_v2_256_ffhq.yaml
(new file, mode 100644)

```yaml
total_iters: 800000
output_dir: output_dir

model:
  name: StyleGAN2Model
  generator:
    name: StyleGANv2Generator
    size: 256
    style_dim: 512
    n_mlp: 8
  discriminator:
    name: StyleGANv2Discriminator
    size: 256
  gan_criterion:
    name: GANLoss
    gan_mode: logistic
    loss_weight: !!float 1
  # r1 regularization for discriminator
  r1_reg_weight: 10.
  # path length regularization for generator
  path_batch_shrink: 2.
  path_reg_weight: 2.
  params:
    gen_iters: 4
    disc_iters: 16

dataset:
  train:
    name: SingleDataset
    dataroot: data/ffhq/images256x256/
    num_workers: 3
    batch_size: 3
    preprocess:
      - name: LoadImageFromFile
        key: A
      - name: Transforms
        input_keys: [A]
        pipeline:
          - name: RandomHorizontalFlip
          - name: Transpose
          - name: Normalize
            mean: [127.5, 127.5, 127.5]
            std: [127.5, 127.5, 127.5]

lr_scheduler:
  name: MultiStepDecay
  learning_rate: 0.002
  milestones: [600000]
  gamma: 0.5

optimizer:
  optimG:
    name: Adam
    beta1: 0.0
    beta2: 0.792
    net_names:
      - gen
  optimD:
    name: Adam
    net_names:
      - disc
    beta1: 0.0
    beta2: 0.9317647058823529

log_config:
  interval: 50
  visiual_interval: 500

snapshot_config:
  interval: 5000
```
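The odd-looking Adam values follow StyleGAN2's lazy-regularization convention: the generator is regularized every `gen_iters = 4` steps and the discriminator every `disc_iters = 16` steps, and `setup_lr_schedulers` in `ppgan/models/styleganv2_model.py` (further down this page) scales each learning rate by the ratio N/(N+1). The `beta2` entries appear to be a base value of 0.99 scaled by the same ratios; the base 0.99 is an assumption inferred from StyleGAN2 convention, and the sketch below only checks that the config numbers are consistent with it:

```python
# Consistency check of the config values above.
gen_ratio = 4 / (4 + 1)       # gen_iters = 4
disc_ratio = 16 / (16 + 1)    # disc_iters = 16

print(0.002 * gen_ratio)      # effective generator lr
print(0.002 * disc_ratio)     # effective discriminator lr
print(0.99 * gen_ratio)       # 0.792               -> optimG beta2
print(0.99 * disc_ratio)      # 0.9317647058823529  -> optimD beta2
```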
ppgan/datasets/preprocess/transforms.py
```diff
@@ -59,19 +59,21 @@ class Transforms():
         data = tuple(data)
         for transform in self.transforms:
             data = transform(data)
             if hasattr(transform, 'params') and isinstance(
                     transform.params, dict):
                 datas.update(transform.params)
 
+        if len(self.input_keys) > 1:
+            for i, k in enumerate(self.input_keys):
+                datas[k] = data[i]
+        else:
+            datas[k] = data
+
         if self.output_keys is not None:
             for i, k in enumerate(self.output_keys):
                 datas[k] = data[i]
             return datas
 
-        for i, k in enumerate(self.input_keys):
-            datas[k] = data[i]
-
         return datas
```
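This gives `Transforms` a dict-in/dict-out contract: `input_keys` pulls entries out of the sample dict, the pipeline runs on the resulting tuple, and the results are written back under the same keys. A minimal, self-contained sketch of that contract (simplified to callable transforms; the real class also copies each transform's `params` back into the dict):

```python
# Hedged sketch of the pack/unpack logic in the hunk above.
def apply_transforms(datas, input_keys, transforms):
    data = tuple(datas[k] for k in input_keys)
    for t in transforms:
        data = t(data)
    if len(input_keys) > 1:
        for i, k in enumerate(input_keys):
            datas[k] = data[i]
    else:
        datas[input_keys[0]] = data
    return datas

sample = apply_transforms({'A': (1, 2, 3)}, ['A'],
                          [lambda d: tuple(x * 2 for x in d)])
print(sample)  # {'A': (2, 4, 6)}
```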
ppgan/datasets/single_dataset.py
```diff
@@ -27,7 +27,7 @@ class SingleDataset(BaseDataset):
             dataroot (str): Directory of dataset.
             preprocess (list[dict]): A sequence of data preprocess config.
         """
-        super(SingleDataset).__init__(self, preprocess)
+        super(SingleDataset, self).__init__(preprocess)
         self.dataroot = dataroot
         self.data_infos = self.prepare_data_infos()
```
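The one-line fix matters: `super(SingleDataset)` builds an *unbound* super object, so the old call never dispatched to the base class. A minimal repro with hypothetical stand-in names:

```python
# Minimal repro of the bug fixed above (Base/Child are stand-in names).
class Base:
    def __init__(self, preprocess):
        self.preprocess = preprocess

class Child(Base):
    def __init__(self, preprocess):
        # Buggy form from the old code: super(Child) is an UNBOUND super
        # object, so the call below never reaches Base.__init__ (CPython
        # raises "TypeError: super() argument 1 must be a type"):
        #   super(Child).__init__(self, preprocess)
        # Fixed form, as in the commit:
        super(Child, self).__init__(preprocess)

assert Child([{'name': 'LoadImageFromFile'}]).preprocess
```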
ppgan/engine/trainer.py
```diff
@@ -123,6 +123,8 @@ class Trainer:
         self.batch_id = 0
         self.global_steps = 0
         self.weight_interval = cfg.snapshot_config.interval
+        if self.by_epoch:
+            self.weight_interval *= self.iters_per_epoch
         self.log_interval = cfg.log_config.interval
         self.visual_interval = cfg.log_config.visiual_interval
         if self.by_epoch:
```

```diff
@@ -143,6 +145,17 @@ class Trainer:
         for net_name, net in self.model.nets.items():
             self.model.nets[net_name] = paddle.DataParallel(net, strategy)
 
+    def learning_rate_scheduler_step(self):
+        if isinstance(self.model.lr_scheduler, dict):
+            for lr_scheduler in self.model.lr_scheduler.values():
+                lr_scheduler.step()
+        elif isinstance(self.model.lr_scheduler,
+                        paddle.optimizer.lr.LRScheduler):
+            self.model.lr_scheduler.step()
+        else:
+            raise ValueError(
+                'lr schedulter must be a dict or an instance of LRScheduler')
+
     def train(self):
         reader_cost_averager = TimeAverager()
         batch_cost_averager = TimeAverager()
```

```diff
@@ -179,7 +192,7 @@ class Trainer:
         if self.current_iter % self.visual_interval == 0:
             self.visual('visual_train')
 
-        self.model.lr_scheduler.step()
+        self.learning_rate_scheduler_step()
 
         if self.validate_interval > -1 and self.current_iter % self.validate_interval == 0:
             self.test()
```
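`learning_rate_scheduler_step` exists because `StyleGAN2Model` returns a dict of per-network schedulers instead of a single scheduler. A minimal sketch of the two shapes the trainer now accepts (the `MultiStepDecay` arguments mirror the config above):

```python
import paddle

# Dict form, as returned by StyleGAN2Model.setup_lr_schedulers:
lr_scheduler = {
    'gen': paddle.optimizer.lr.MultiStepDecay(
        learning_rate=0.002 * 4 / 5, milestones=[600000], gamma=0.5),
    'disc': paddle.optimizer.lr.MultiStepDecay(
        learning_rate=0.002 * 16 / 17, milestones=[600000], gamma=0.5),
}

# The trainer steps whichever form it finds:
if isinstance(lr_scheduler, dict):
    for sched in lr_scheduler.values():
        sched.step()
elif isinstance(lr_scheduler, paddle.optimizer.lr.LRScheduler):
    lr_scheduler.step()
```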
ppgan/models/__init__.py
```diff
@@ -22,5 +22,6 @@ from .esrgan_model import ESRGAN
 from .ugatit_model import UGATITModel
 from .dc_gan_model import DCGANModel
 from .animeganv2_model import AnimeGANV2Model, AnimeGANV2PreTrainModel
+from .styleganv2_model import StyleGAN2Model
 from .wav2lip_model import Wav2LipModel
 from .wav2lip_hq_model import Wav2LipModelHq
```
ppgan/models/discriminators/discriminator_styleganv2.py
```diff
@@ -35,22 +35,22 @@ class ConvLayer(nn.Sequential):
         activate=True,
     ):
         layers = []
 
         if downsample:
             factor = 2
             p = (len(blur_kernel) - factor) + (kernel_size - 1)
             pad0 = (p + 1) // 2
             pad1 = p // 2
 
             layers.append(Upfirdn2dBlur(blur_kernel, pad=(pad0, pad1)))
 
             stride = 2
             self.padding = 0
 
         else:
             stride = 1
             self.padding = kernel_size // 2
 
         layers.append(
             EqualConv2D(
                 in_channel,
```

```diff
@@ -59,41 +59,58 @@ class ConvLayer(nn.Sequential):
                 padding=self.padding,
                 stride=stride,
                 bias=bias and not activate,
             ))
 
         if activate:
             layers.append(FusedLeakyReLU(out_channel, bias=bias))
 
         super().__init__(*layers)
 
 
 class ResBlock(nn.Layer):
     def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
         super().__init__()
 
         self.conv1 = ConvLayer(in_channel, in_channel, 3)
         self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
 
         self.skip = ConvLayer(in_channel,
                               out_channel,
                               1,
                               downsample=True,
                               activate=False,
                               bias=False)
 
     def forward(self, input):
         out = self.conv1(input)
         out = self.conv2(out)
 
         skip = self.skip(input)
         out = (out + skip) / math.sqrt(2)
 
         return out
 
 
+# temporally solve pow double grad problem
+def var(x, axis=None, unbiased=True, keepdim=False, name=None):
+    u = paddle.mean(x, axis, True, name)
+    out = paddle.sum((x - u) * (x - u), axis, keepdim=keepdim, name=name)
+    n = paddle.cast(paddle.numel(x), x.dtype) \
+        / paddle.cast(paddle.numel(out), x.dtype)
+    if unbiased:
+        one_const = paddle.ones([1], x.dtype)
+        n = paddle.where(n > one_const, n - 1., one_const)
+    out /= n
+    return out
+
+
 @DISCRIMINATORS.register()
 class StyleGANv2Discriminator(nn.Layer):
     def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
         super().__init__()
 
         channels = {
             4: 512,
             8: 512,
```

```diff
@@ -105,47 +122,48 @@ class StyleGANv2Discriminator(nn.Layer):
             512: 32 * channel_multiplier,
             1024: 16 * channel_multiplier,
         }
 
         convs = [ConvLayer(3, channels[size], 1)]
 
         log_size = int(math.log(size, 2))
 
         in_channel = channels[size]
 
         for i in range(log_size, 2, -1):
             out_channel = channels[2**(i - 1)]
 
             convs.append(ResBlock(in_channel, out_channel, blur_kernel))
 
             in_channel = out_channel
 
         self.convs = nn.Sequential(*convs)
 
         self.stddev_group = 4
         self.stddev_feat = 1
 
         self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
         self.final_linear = nn.Sequential(
             EqualLinear(channels[4] * 4 * 4,
                         channels[4],
                         activation="fused_lrelu"),
             EqualLinear(channels[4], 1),
         )
 
     def forward(self, input):
         out = self.convs(input)
 
         batch, channel, height, width = out.shape
         group = min(batch, self.stddev_group)
         stddev = out.reshape((
             group,
             -1,
             self.stddev_feat,
             channel // self.stddev_feat,
             height, width))
-        stddev = paddle.sqrt(stddev.var(0, unbiased=False) + 1e-8)
+        stddev = paddle.sqrt(var(stddev, 0, unbiased=False) + 1e-8)
         stddev = stddev.mean([2, 3, 4], keepdim=True).squeeze(2)
         stddev = stddev.tile((group, 1, height, width))
         out = paddle.concat([out, stddev], 1)
 
         out = self.final_conv(out)
         out = out.reshape((batch, -1))
         out = self.final_linear(out)
 
         return out
```
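The module-level `var` replaces the tensor method `stddev.var(...)` because, per its comment, `pow`'s double gradient was problematic in Paddle at the time; the workaround computes the mean of squared deviations directly. A quick numpy stand-in check that the formula reproduces the ordinary (biased) variance used by the minibatch-stddev layer:

```python
# Numpy check of var(x, 0, unbiased=False) as defined above.
import numpy as np

x = np.random.randn(4, 8, 16).astype('float32')
u = x.mean(axis=0, keepdims=True)
out = ((x - u) * (x - u)).sum(axis=0)
n = x.size / out.size            # = 4, the length of the reduced axis
assert np.allclose(out / n, x.var(axis=0), atol=1e-5)
```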
ppgan/models/generators/generator_styleganv2.py
```diff
@@ -27,11 +27,12 @@ from ...modules.upfirdn2d import Upfirdn2dUpsample, Upfirdn2dBlur
 class PixelNorm(nn.Layer):
     def __init__(self):
         super().__init__()
 
     def forward(self, input):
-        return input * paddle.rsqrt(paddle.mean(input ** 2, 1, keepdim=True) + 1e-8)
+        return input * paddle.rsqrt(
+            paddle.mean(input * input, 1, keepdim=True) + 1e-8)
 
 
 class ModulatedConv2D(nn.Layer):
     def __init__(
         self,
```

```diff
@@ -45,75 +46,78 @@ class ModulatedConv2D(nn.Layer):
         blur_kernel=[1, 3, 3, 1],
     ):
         super().__init__()
 
         self.eps = 1e-8
         self.kernel_size = kernel_size
         self.in_channel = in_channel
         self.out_channel = out_channel
         self.upsample = upsample
         self.downsample = downsample
 
         if upsample:
             factor = 2
             p = (len(blur_kernel) - factor) - (kernel_size - 1)
             pad0 = (p + 1) // 2 + factor - 1
             pad1 = p // 2 + 1
 
             self.blur = Upfirdn2dBlur(blur_kernel,
                                       pad=(pad0, pad1),
                                       upsample_factor=factor)
 
         if downsample:
             factor = 2
             p = (len(blur_kernel) - factor) + (kernel_size - 1)
             pad0 = (p + 1) // 2
             pad1 = p // 2
 
             self.blur = Upfirdn2dBlur(blur_kernel, pad=(pad0, pad1))
 
-        fan_in = in_channel * kernel_size ** 2
+        fan_in = in_channel * (kernel_size * kernel_size)
         self.scale = 1 / math.sqrt(fan_in)
         self.padding = kernel_size // 2
 
         self.weight = self.create_parameter(
             (1, out_channel, in_channel, kernel_size, kernel_size),
             default_initializer=nn.initializer.Normal())
 
         self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
 
         self.demodulate = demodulate
 
     def __repr__(self):
         return (
             f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, "
             f"upsample={self.upsample}, downsample={self.downsample})")
 
     def forward(self, input, style):
         batch, in_channel, height, width = input.shape
 
         style = self.modulation(style).reshape((batch, 1, in_channel, 1, 1))
         weight = self.scale * self.weight * style
 
         if self.demodulate:
-            demod = paddle.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
+            demod = paddle.rsqrt((weight * weight).sum([2, 3, 4]) + 1e-8)
             weight = weight * demod.reshape((batch, self.out_channel, 1, 1, 1))
 
         weight = weight.reshape((batch * self.out_channel, in_channel,
                                  self.kernel_size, self.kernel_size))
 
         if self.upsample:
             input = input.reshape((1, batch * in_channel, height, width))
             weight = weight.reshape((batch, self.out_channel, in_channel,
                                      self.kernel_size, self.kernel_size))
             weight = weight.transpose((0, 2, 1, 3, 4)).reshape(
                 (batch * in_channel, self.out_channel, self.kernel_size,
                  self.kernel_size))
             out = F.conv2d_transpose(input,
                                      weight,
                                      padding=0,
                                      stride=2,
                                      groups=batch)
             _, _, height, width = out.shape
             out = out.reshape((batch, self.out_channel, height, width))
             out = self.blur(out)
 
         elif self.downsample:
             input = self.blur(input)
             _, _, height, width = input.shape
```

```diff
@@ -121,43 +125,46 @@ class ModulatedConv2D(nn.Layer):
             out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
             _, _, height, width = out.shape
             out = out.reshape((batch, self.out_channel, height, width))
 
         else:
             input = input.reshape((1, batch * in_channel, height, width))
             out = F.conv2d(input, weight, padding=self.padding, groups=batch)
             _, _, height, width = out.shape
             out = out.reshape((batch, self.out_channel, height, width))
 
         return out
 
 
 class NoiseInjection(nn.Layer):
     def __init__(self):
         super().__init__()
 
-        self.weight = self.create_parameter((1,), default_initializer=nn.initializer.Constant(0.0))
+        self.weight = self.create_parameter(
+            (1, ), default_initializer=nn.initializer.Constant(0.0))
 
     def forward(self, image, noise=None):
         if noise is None:
             batch, _, height, width = image.shape
             noise = paddle.randn((batch, 1, height, width))
 
         return image + self.weight * noise
 
 
 class ConstantInput(nn.Layer):
     def __init__(self, channel, size=4):
         super().__init__()
 
-        self.input = self.create_parameter((1, channel, size, size), default_initializer=nn.initializer.Normal())
+        self.input = self.create_parameter(
+            (1, channel, size, size),
+            default_initializer=nn.initializer.Normal())
 
     def forward(self, input):
         batch = input.shape[0]
         out = self.input.tile((batch, 1, 1, 1))
 
         return out
 
 
 class StyledConv(nn.Layer):
     def __init__(
         self,
```

```diff
@@ -170,7 +177,7 @@ class StyledConv(nn.Layer):
         demodulate=True,
     ):
         super().__init__()
 
         self.conv = ModulatedConv2D(
             in_channel,
             out_channel,
```

```diff
@@ -180,40 +187,49 @@ class StyledConv(nn.Layer):
             blur_kernel=blur_kernel,
             demodulate=demodulate,
         )
 
         self.noise = NoiseInjection()
         self.activate = FusedLeakyReLU(out_channel)
 
     def forward(self, input, style, noise=None):
         out = self.conv(input, style)
         out = self.noise(out, noise=noise)
         out = self.activate(out)
 
         return out
 
 
 class ToRGB(nn.Layer):
     def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
         super().__init__()
 
         if upsample:
             self.upsample = Upfirdn2dUpsample(blur_kernel)
 
         self.conv = ModulatedConv2D(in_channel,
                                     3,
                                     1,
                                     style_dim,
                                     demodulate=False)
         self.bias = self.create_parameter((1, 3, 1, 1),
                                           nn.initializer.Constant(0.0))
 
     def forward(self, input, style, skip=None):
         out = self.conv(input, style)
         out = out + self.bias
 
         if skip is not None:
             skip = self.upsample(skip)
             out = out + skip
 
         return out
 
 
 @GENERATORS.register()
 class StyleGANv2Generator(nn.Layer):
     def __init__(
```

```diff
@@ -226,22 +242,22 @@ class StyleGANv2Generator(nn.Layer):
         lr_mlp=0.01,
     ):
         super().__init__()
 
         self.size = size
 
         self.style_dim = style_dim
 
         layers = [PixelNorm()]
 
         for i in range(n_mlp):
             layers.append(
                 EqualLinear(style_dim,
                             style_dim,
                             lr_mul=lr_mlp,
                             activation="fused_lrelu"))
 
         self.style = nn.Sequential(*layers)
 
         self.channels = {
             4: 512,
             8: 512,
```

```diff
@@ -253,31 +269,34 @@ class StyleGANv2Generator(nn.Layer):
             512: 32 * channel_multiplier,
             1024: 16 * channel_multiplier,
         }
 
         self.input = ConstantInput(self.channels[4])
         self.conv1 = StyledConv(self.channels[4],
                                 self.channels[4],
                                 3,
                                 style_dim,
                                 blur_kernel=blur_kernel)
         self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)
 
         self.log_size = int(math.log(size, 2))
         self.num_layers = (self.log_size - 2) * 2 + 1
 
         self.convs = nn.LayerList()
         self.upsamples = nn.LayerList()
         self.to_rgbs = nn.LayerList()
         self.noises = nn.Layer()
 
         in_channel = self.channels[4]
 
         for layer_idx in range(self.num_layers):
             res = (layer_idx + 5) // 2
             shape = [1, 1, 2**res, 2**res]
             self.noises.register_buffer(f"noise_{layer_idx}",
                                         paddle.randn(shape))
 
         for i in range(3, self.log_size + 1):
             out_channel = self.channels[2**i]
 
             self.convs.append(
                 StyledConv(
                     in_channel,
```

```diff
@@ -286,41 +305,39 @@ class StyleGANv2Generator(nn.Layer):
                     style_dim,
                     upsample=True,
                     blur_kernel=blur_kernel,
                 ))
 
             self.convs.append(
                 StyledConv(out_channel,
                            out_channel,
                            3,
                            style_dim,
                            blur_kernel=blur_kernel))
 
             self.to_rgbs.append(ToRGB(out_channel, style_dim))
 
             in_channel = out_channel
 
         self.n_latent = self.log_size * 2 - 2
 
     def make_noise(self):
         noises = [paddle.randn((1, 1, 2**2, 2**2))]
 
         for i in range(3, self.log_size + 1):
             for _ in range(2):
                 noises.append(paddle.randn((1, 1, 2**i, 2**i)))
 
         return noises
 
     def mean_latent(self, n_latent):
         latent_in = paddle.randn((n_latent, self.style_dim))
         latent = self.style(latent_in).mean(0, keepdim=True)
 
         return latent
 
     def get_latent(self, input):
         return self.style(input)
 
     def forward(
         self,
         styles,
```

```diff
@@ -334,62 +351,65 @@ class StyleGANv2Generator(nn.Layer):
     ):
         if not input_is_latent:
             styles = [self.style(s) for s in styles]
 
         if noise is None:
             if randomize_noise:
                 noise = [None] * self.num_layers
             else:
                 noise = [
                     getattr(self.noises, f"noise_{i}")
                     for i in range(self.num_layers)
                 ]
 
         if truncation < 1:
             style_t = []
 
             for style in styles:
                 style_t.append(truncation_latent + truncation *
                                (style - truncation_latent))
 
             styles = style_t
 
         if len(styles) < 2:
             inject_index = self.n_latent
 
             if styles[0].ndim < 3:
                 latent = styles[0].unsqueeze(1).tile((1, inject_index, 1))
             else:
                 latent = styles[0]
 
         else:
             if inject_index is None:
                 inject_index = random.randint(1, self.n_latent - 1)
 
             latent = styles[0].unsqueeze(1).tile((1, inject_index, 1))
             latent2 = styles[1].unsqueeze(1).tile(
                 (1, self.n_latent - inject_index, 1))
 
             latent = paddle.concat([latent, latent2], 1)
 
         out = self.input(latent)
         out = self.conv1(out, latent[:, 0], noise=noise[0])
 
         skip = self.to_rgb1(out, latent[:, 1])
 
         i = 1
         for conv1, conv2, noise1, noise2, to_rgb in zip(
                 self.convs[::2], self.convs[1::2], noise[1::2],
                 noise[2::2], self.to_rgbs):
             out = conv1(out, latent[:, i], noise=noise1)
             out = conv2(out, latent[:, i + 1], noise=noise2)
             skip = to_rgb(out, latent[:, i + 2], skip)
 
             i += 2
 
         image = skip
 
         if return_latents:
             return image, latent
 
         else:
             return image, None
```
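Two details of `forward` are worth spelling out. The `truncation < 1` branch is the standard truncation trick, pulling each mapped style toward a mean latent:

$$\mathbf{w}' = \bar{\mathbf{w}} + \psi\,(\mathbf{w} - \bar{\mathbf{w}}),$$

where $\psi$ is `truncation` and $\bar{\mathbf{w}}$ is `truncation_latent` (typically obtained from `mean_latent`); $\psi = 1$ leaves styles untouched, and smaller $\psi$ trades diversity for fidelity. The two-style branch implements style mixing: `styles[0]` controls the layers before the random `inject_index` and `styles[1]` the layers after it.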
ppgan/models/styleganv2_model.py
(new file, mode 100644)
```python
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import random
import paddle
import paddle.nn as nn

from .base_model import BaseModel
from .builder import MODELS
from .criterions import build_criterion
from .generators.builder import build_generator
from .discriminators.builder import build_discriminator
from ..solver import build_lr_scheduler, build_optimizer


def r1_penalty(real_pred, real_img):
    """
    R1 regularization for discriminator. The core idea is to
    penalize the gradient on real data alone: when the
    generator distribution produces the true data distribution
    and the discriminator is equal to 0 on the data manifold, the
    gradient penalty ensures that the discriminator cannot create
    a non-zero gradient orthogonal to the data manifold without
    suffering a loss in the GAN game.

    Ref:
    Eq. 9 in Which training methods for GANs do actually converge.
    """
    grad_real = paddle.grad(outputs=real_pred.sum(),
                            inputs=real_img,
                            create_graph=True)[0]
    grad_penalty = (grad_real * grad_real).reshape(
        [grad_real.shape[0], -1]).sum(1).mean()
    return grad_penalty
```
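In symbols (Eq. 9 of Mescheder et al., which the docstring cites), the penalty is

$$R_1(\psi) = \frac{\gamma}{2}\,\mathbb{E}_{x \sim p_{\mathcal{D}}}\!\left[\lVert \nabla_x D_\psi(x) \rVert^2\right];$$

the function returns only the expectation term (per-sample squared gradient norm, averaged over the batch), and the $\gamma/2$ factor (`r1_reg_weight / 2`) is applied later in `train_iter`.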
```python
def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
    noise = paddle.randn(fake_img.shape) / math.sqrt(
        fake_img.shape[2] * fake_img.shape[3])
    grad = paddle.grad(outputs=(fake_img * noise).sum(),
                       inputs=latents,
                       create_graph=True)[0]
    path_lengths = paddle.sqrt((grad * grad).sum(2).mean(1))

    path_mean = mean_path_length + decay * (path_lengths.mean() -
                                            mean_path_length)

    path_penalty = ((path_lengths - path_mean) *
                    (path_lengths - path_mean)).mean()

    return path_penalty, path_lengths.detach().mean(), path_mean.detach()
```
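This is StyleGAN2's path-length regularizer: with $y$ a random image-shaped direction scaled by $1/\sqrt{HW}$, it penalizes deviation of the Jacobian-vector-product norm from a running mean $a$:

$$\mathcal{L}_{\text{path}} = \mathbb{E}_{w,\,y}\!\left[\big(\lVert J_w^{\mathsf{T}} y \rVert_2 - a\big)^2\right], \qquad a \leftarrow a + 0.01\,\big(\mathbb{E}\lVert J_w^{\mathsf{T}} y \rVert_2 - a\big),$$

matching `path_lengths`, `path_mean` (decay 0.01), and `path_penalty` above.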
```python
@MODELS.register()
class StyleGAN2Model(BaseModel):
    """
    This class implements the StyleGAN V2 model, for learning image generation from random noise.

    StyleGAN2 paper: https://arxiv.org/pdf/1912.04958.pdf
    """
    def __init__(self,
                 generator,
                 discriminator=None,
                 gan_criterion=None,
                 num_style_feat=512,
                 mixing_prob=0.9,
                 r1_reg_weight=10.,
                 path_reg_weight=2.,
                 path_batch_shrink=2.,
                 params=None):
        """Initialize the StyleGAN2 model class.

        Args:
            generator (dict): config of generator.
            discriminator (dict): config of discriminator.
            gan_criterion (dict): config of gan criterion.
        """
        super(StyleGAN2Model, self).__init__(params)
        self.gen_iters = 4 if self.params is None else self.params.get(
            'gen_iters', 4)
        self.disc_iters = 16 if self.params is None else self.params.get(
            'disc_iters', 16)
        self.disc_start_iters = (0 if self.params is None else
                                 self.params.get('disc_start_iters', 0))
        self.visual_iters = (500 if self.params is None else
                             self.params.get('visual_iters', 500))

        self.mixing_prob = mixing_prob
        self.num_style_feat = num_style_feat
        self.r1_reg_weight = r1_reg_weight
        self.path_reg_weight = path_reg_weight
        self.path_batch_shrink = path_batch_shrink
        self.mean_path_length = 0

        self.nets['gen'] = build_generator(generator)

        # define discriminators
        if discriminator:
            self.nets['disc'] = build_discriminator(discriminator)

        self.nets['gen_ema'] = build_generator(generator)
        self.model_ema(0)

        self.nets['gen'].train()
        self.nets['gen_ema'].eval()
        self.nets['disc'].train()
        self.current_iter = 1

        # define loss functions
        if gan_criterion:
            self.gan_criterion = build_criterion(gan_criterion)

    def setup_lr_schedulers(self, cfg):
        self.lr_scheduler = dict()
        gen_cfg = cfg.copy()
        net_g_reg_ratio = self.gen_iters / (self.gen_iters + 1)
        gen_cfg['learning_rate'] = cfg['learning_rate'] * net_g_reg_ratio
        self.lr_scheduler['gen'] = build_lr_scheduler(gen_cfg)

        disc_cfg = cfg.copy()
        net_d_reg_ratio = self.disc_iters / (self.disc_iters + 1)
        disc_cfg['learning_rate'] = cfg['learning_rate'] * net_d_reg_ratio
        self.lr_scheduler['disc'] = build_lr_scheduler(disc_cfg)

        return self.lr_scheduler
```
```python
    def setup_optimizers(self, lr, cfg):
        for opt_name, opt_cfg in cfg.items():
            if opt_name == 'optimG':
                _lr = lr['gen']
            elif opt_name == 'optimD':
                _lr = lr['disc']
            else:
                raise ValueError("opt name must be in ['optimG', optimD]")
            cfg_ = opt_cfg.copy()
            net_names = cfg_.pop('net_names')
            parameters = []
            for net_name in net_names:
                parameters += self.nets[net_name].parameters()
            self.optimizers[opt_name] = build_optimizer(cfg_, _lr, parameters)

        return self.optimizers

    def get_bare_model(self, net):
        """Get bare model, especially under wrapping with DataParallel.
        """
        if isinstance(net, (paddle.DataParallel)):
            net = net._layers
        return net

    def model_ema(self, decay=0.999):
        net_g = self.get_bare_model(self.nets['gen'])
        net_g_params = dict(net_g.named_parameters())

        neg_g_ema = self.get_bare_model(self.nets['gen_ema'])
        net_g_ema_params = dict(neg_g_ema.named_parameters())

        for k in net_g_ema_params.keys():
            net_g_ema_params[k].set_value(net_g_ema_params[k] * (decay) +
                                          (net_g_params[k] * (1 - decay)))
```
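`model_ema` keeps a shadow copy of the generator updated as

$$\theta_{\text{ema}} \leftarrow d\,\theta_{\text{ema}} + (1 - d)\,\theta_{\text{gen}},$$

and `train_iter` calls it with $d = 0.5^{32/10000} \approx 0.9978$, i.e. the old weights' contribution halves every $10000/32 \approx 312.5$ updates. Reading the constants as StyleGAN's usual 10k-image half-life at batch size 32 is an inference from the numbers, not something the commit states.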
```python
    def setup_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        Args:
            input (dict): include the data itself and its metadata information.
        """
        self.real_img = paddle.fluid.dygraph.to_variable(input['A'])

    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
        pass

    def make_noise(self, batch, num_noise):
        if num_noise == 1:
            noises = paddle.randn([batch, self.num_style_feat])
        else:
            noises = []
            for _ in range(num_noise):
                noises.append(paddle.randn([batch, self.num_style_feat]))
        return noises

    def mixing_noise(self, batch, prob):
        if random.random() < prob:
            return self.make_noise(batch, 2)
        else:
            return [self.make_noise(batch, 1)]
```
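A standalone sketch of what `mixing_noise` hands the generator (it mirrors the method above; shapes assume the default `num_style_feat=512`):

```python
import random
import paddle

def mixing_noise(batch, prob, num_style_feat=512):
    # With probability `prob`, return two latent codes (style mixing);
    # otherwise a single code. The generator accepts a list of 1 or 2.
    if random.random() < prob:
        return [paddle.randn([batch, num_style_feat]) for _ in range(2)]
    return [paddle.randn([batch, num_style_feat])]

codes = mixing_noise(3, 0.9)
print(len(codes), codes[0].shape)  # 1 or 2 codes, each [3, 512]
```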
```python
    def train_iter(self, optimizers=None):
        current_iter = self.current_iter
        self.set_requires_grad(self.nets['disc'], True)
        optimizers['optimD'].clear_grad()

        batch = self.real_img.shape[0]

        noise = self.mixing_noise(batch, self.mixing_prob)
        fake_img, _ = self.nets['gen'](noise)
        self.visual_items['real_img'] = self.real_img
        self.visual_items['fake_img'] = fake_img

        fake_pred = self.nets['disc'](fake_img.detach())
        real_pred = self.nets['disc'](self.real_img)

        # wgan loss with softplus (logistic loss) for discriminator
        l_d_total = 0.
        l_d = self.gan_criterion(real_pred, True,
                                 is_disc=True) + self.gan_criterion(
                                     fake_pred, False, is_disc=True)
        self.losses['l_d'] = l_d
        # In wgan, real_score should be positive and fake_score should be
        # negative
        self.losses['real_score'] = real_pred.detach().mean()
        self.losses['fake_score'] = fake_pred.detach().mean()
        l_d_total += l_d

        if current_iter % self.disc_iters == 0:
            self.real_img.stop_gradient = False
            real_pred = self.nets['disc'](self.real_img)
            l_d_r1 = r1_penalty(real_pred, self.real_img)
            l_d_r1 = (self.r1_reg_weight / 2 * l_d_r1 * self.disc_iters +
                      0 * real_pred[0])

            self.losses['l_d_r1'] = l_d_r1.detach().mean()
            l_d_total += l_d_r1

        l_d_total.backward()
        optimizers['optimD'].step()

        self.set_requires_grad(self.nets['disc'], False)
        optimizers['optimG'].clear_grad()

        noise = self.mixing_noise(batch, self.mixing_prob)
        fake_img, _ = self.nets['gen'](noise)
        fake_pred = self.nets['disc'](fake_img)

        # wgan loss with softplus (non-saturating loss) for generator
        l_g_total = 0.
        l_g = self.gan_criterion(fake_pred, True, is_disc=False)
        self.losses['l_g'] = l_g
        l_g_total += l_g

        if current_iter % self.gen_iters == 0:
            path_batch_size = max(1, batch // self.path_batch_shrink)
            noise = self.mixing_noise(path_batch_size, self.mixing_prob)
            fake_img, latents = self.nets['gen'](noise, return_latents=True)
            l_g_path, path_lengths, self.mean_path_length = g_path_regularize(
                fake_img, latents, self.mean_path_length)

            l_g_path = (self.path_reg_weight * self.gen_iters * l_g_path +
                        0 * fake_img[0, 0, 0, 0])

            l_g_total += l_g_path
            self.losses['l_g_path'] = l_g_path.detach().mean()
            self.losses['path_length'] = path_lengths

        l_g_total.backward()
        optimizers['optimG'].step()

        # EMA
        self.model_ema(decay=0.5**(32 / (10 * 1000)))

        if self.current_iter % self.visual_iters:
            sample_z = [self.make_noise(1, 1)]
            sample, _ = self.nets['gen_ema'](sample_z)
            self.visual_items['fake_img_ema'] = sample

        self.current_iter += 1
```
ppgan/modules/equalized.py
```diff
@@ -24,25 +24,30 @@ class EqualConv2D(nn.Layer):
     """This convolutional layer class stabilizes the learning rate changes of its parameters.
     Equalizing learning rate keeps the weights in the network at a similar scale during training.
     """
     def __init__(self,
                  in_channel,
                  out_channel,
                  kernel_size,
                  stride=1,
                  padding=0,
                  bias=True):
         super().__init__()
 
         self.weight = self.create_parameter(
             (out_channel, in_channel, kernel_size, kernel_size),
             default_initializer=nn.initializer.Normal())
-        self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
+        self.scale = 1 / math.sqrt(in_channel * (kernel_size * kernel_size))
 
         self.stride = stride
         self.padding = padding
 
         if bias:
             self.bias = self.create_parameter((out_channel, ),
                                               nn.initializer.Constant(0.0))
 
         else:
             self.bias = None
 
     def forward(self, input):
         out = F.conv2d(
             input,
```

```diff
@@ -51,51 +56,57 @@ class EqualConv2D(nn.Layer):
             stride=self.stride,
             padding=self.padding,
         )
 
         return out
 
     def __repr__(self):
         return (
             f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},"
             f"{self.weight.shape[2]}, stride={self.stride}, padding={self.padding})"
         )
 
 
 class EqualLinear(nn.Layer):
     """This linear layer class stabilizes the learning rate changes of its parameters.
     Equalizing learning rate keeps the weights in the network at a similar scale during training.
     """
     def __init__(self,
                  in_dim,
                  out_dim,
                  bias=True,
                  bias_init=0,
                  lr_mul=1,
                  activation=None):
         super().__init__()
 
         self.weight = self.create_parameter(
             (in_dim, out_dim), default_initializer=nn.initializer.Normal())
-        self.weight[:] = (self.weight / lr_mul).detach()
+        self.weight.set_value((self.weight / lr_mul))
 
         if bias:
             self.bias = self.create_parameter(
                 (out_dim, ), nn.initializer.Constant(bias_init))
 
         else:
             self.bias = None
 
         self.activation = activation
 
         self.scale = (1 / math.sqrt(in_dim)) * lr_mul
         self.lr_mul = lr_mul
 
     def forward(self, input):
         if self.activation:
             out = F.linear(input, self.weight * self.scale)
             out = fused_leaky_relu(out, self.bias * self.lr_mul)
 
         else:
             out = F.linear(input,
                            self.weight * self.scale,
                            bias=self.bias * self.lr_mul)
 
         return out
 
     def __repr__(self):
         return (
             f"{self.__class__.__name__}({self.weight.shape[0]}, {self.weight.shape[1]})"
```
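A minimal numpy sketch of the "equalized learning rate" trick these layers implement: weights are stored near unit scale (divided by `lr_mul` once at init) and multiplied by a He-style constant on *every* forward pass, so activations keep unit scale and the effective per-weight step size is uniform across layers regardless of `lr_mul`:

```python
import numpy as np

in_dim, lr_mul = 512, 0.01
w = np.random.randn(in_dim, in_dim) / lr_mul    # stored weight
scale = (1 / np.sqrt(in_dim)) * lr_mul          # runtime multiplier
x = np.random.randn(4, in_dim)
y = x @ (w * scale)                             # forward uses w * scale
print(y.std())                                  # ~1, independent of lr_mul
```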
ppgan/modules/upfirdn2d.py
```diff
@@ -15,37 +15,35 @@
 import paddle
 import paddle.nn as nn
 import paddle.nn.functional as F
 
 
 def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0,
                      pad_x1, pad_y0, pad_y1):
     _, channel, in_h, in_w = input.shape
     input = input.reshape((-1, in_h, in_w, 1))
 
     _, in_h, in_w, minor = input.shape
     kernel_h, kernel_w = kernel.shape
 
     out = input.reshape((-1, in_h, 1, in_w, 1, minor))
     out = out.transpose((0, 1, 3, 5, 2, 4))
     out = out.reshape((-1, 1, 1, 1))
     out = F.pad(out, [0, up_x - 1, 0, up_y - 1])
     out = out.reshape((-1, in_h, in_w, minor, up_y, up_x))
     out = out.transpose((0, 3, 1, 4, 2, 5))
     out = out.reshape((-1, minor, in_h * up_y, in_w * up_x))
 
     out = F.pad(
         out, [max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
     out = out[:, :,
               max(-pad_y0, 0):out.shape[2] - max(-pad_y1, 0),
               max(-pad_x0, 0):out.shape[3] - max(-pad_x1, 0), ]
 
     out = out.reshape(
         [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
     w = paddle.flip(kernel, [0, 1]).reshape((1, 1, kernel_h, kernel_w))
     out = F.conv2d(out, w)
     out = out.reshape((
```

```diff
@@ -56,88 +54,95 @@ def upfirdn2d_native(
     ))
     out = out.transpose((0, 2, 3, 1))
     out = out[:, ::down_y, ::down_x, :]
 
     out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
     out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
 
     return out.reshape((-1, channel, out_h, out_w))
 
 
 def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
     out = upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1],
                            pad[0], pad[1])
 
     return out
 
 
 def make_kernel(k):
     k = paddle.to_tensor(k, dtype='float32')
 
     if k.ndim == 1:
         k = k.unsqueeze(0) * k.unsqueeze(1)
 
     k /= k.sum()
 
     return k
 
 
 class Upfirdn2dUpsample(nn.Layer):
     def __init__(self, kernel, factor=2):
         super().__init__()
 
         self.factor = factor
-        kernel = make_kernel(kernel) * (factor ** 2)
+        kernel = make_kernel(kernel) * (factor * factor)
         self.register_buffer("kernel", kernel)
 
         p = kernel.shape[0] - factor
 
         pad0 = (p + 1) // 2 + factor - 1
         pad1 = p // 2
 
         self.pad = (pad0, pad1)
 
     def forward(self, input):
         out = upfirdn2d(input,
                         self.kernel,
                         up=self.factor,
                         down=1,
                         pad=self.pad)
 
         return out
 
 
 class Upfirdn2dDownsample(nn.Layer):
     def __init__(self, kernel, factor=2):
         super().__init__()
 
         self.factor = factor
         kernel = make_kernel(kernel)
         self.register_buffer("kernel", kernel)
 
         p = kernel.shape[0] - factor
 
         pad0 = (p + 1) // 2
         pad1 = p // 2
 
         self.pad = (pad0, pad1)
 
     def forward(self, input):
         out = upfirdn2d(input,
                         self.kernel,
                         up=1,
                         down=self.factor,
                         pad=self.pad)
 
         return out
 
 
 class Upfirdn2dBlur(nn.Layer):
     def __init__(self, kernel, pad, upsample_factor=1):
         super().__init__()
 
         kernel = make_kernel(kernel)
 
         if upsample_factor > 1:
-            kernel = kernel * (upsample_factor ** 2)
+            kernel = kernel * (upsample_factor * upsample_factor)
 
-        self.register_buffer("kernel", kernel)
+        self.register_buffer("kernel", kernel, persistable=False)
 
         self.pad = pad
 
     def forward(self, input):
         out = upfirdn2d(input, self.kernel, pad=self.pad)
 
         return out
```
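A quick check of what `make_kernel` produces: the 1-D tap list `[1, 3, 3, 1]` used throughout StyleGAN2 becomes its outer product, normalized to sum 1, i.e. the separable binomial blur kernel:

```python
import numpy as np

k = np.array([1., 3., 3., 1.])
k2d = np.outer(k, k)      # 4x4 separable kernel
k2d /= k2d.sum()          # total weight 64 -> normalized to 1
print(k2d * 64)           # integer binomial pattern: [1,3,3,1] x [1,3,3,1]
```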
ppgan/utils/audio.py
```diff
-import librosa
-import librosa.filters
 import numpy as np
 from scipy import signal
 from scipy.io import wavfile
+from paddle.utils import try_import
 
 from .audio_config import get_audio_config
 
 audio_config = get_audio_config()
 
 
 def load_wav(path, sr):
+    librosa = try_import('librosa')
     return librosa.core.load(path, sr=sr)[0]
```

```diff
@@ -19,6 +19,7 @@ def save_wav(wav, path, sr):
 def save_wavenet_wav(wav, path, sr):
+    librosa = try_import('librosa')
     librosa.output.write_wav(path, wav, sr=sr)
```

```diff
@@ -75,6 +76,7 @@ def _stft(y):
     if audio_config.use_lws:
         return _lws_processor(audio_config).stft(y).T
     else:
+        librosa = try_import('librosa')
         return librosa.stft(y=y,
                             n_fft=audio_config.n_fft,
                             hop_length=get_hop_size(),
```

```diff
@@ -123,6 +125,7 @@ def _linear_to_mel(spectogram):
 def _build_mel_basis():
     assert audio_config.fmax <= audio_config.sample_rate // 2
+    librosa = try_import('librosa')
     return librosa.filters.mel(audio_config.sample_rate,
                                audio_config.n_fft,
                                n_mels=audio_config.num_mels,
```