PaddlePaddle / PaddleGAN

Commit fccd4d47 (unverified), authored May 06, 2021 by LielinJiang, committed via GitHub on May 06, 2021.
fix attributed error for DataParallel (#303)
Parent: fa53795c
Showing 2 changed files with 134 additions and 46 deletions (+134, -46):

- ppgan/models/gan_model.py (+6, -1)
- ppgan/models/starganv2_model.py (+128, -45)
ppgan/models/gan_model.py @ fccd4d47

```diff
@@ -91,7 +91,12 @@ class GANModel(BaseModel):
             self.n_class = 0
 
         batch_size = self.D_real_inputs[0].shape[0]
-        self.G_inputs = self.nets['netG'].random_inputs(batch_size)
+        if isinstance(self.nets['netG'], paddle.DataParallel):
+            self.G_inputs = self.nets['netG']._layers.random_inputs(
+                batch_size)
+        else:
+            self.G_inputs = self.nets['netG'].random_inputs(
+                batch_size)
 
         if not isinstance(self.G_inputs, (list, tuple)):
             self.G_inputs = [self.G_inputs]
```
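This hunk is the commit's core pattern: paddle.DataParallel (as of this 2021-era Paddle) does not forward user-defined methods such as random_inputs to the wrapped Layer, so they must be reached through the wrapper's ._layers attribute. A minimal sketch of the failure mode and the guard, using a hypothetical ToyGenerator:

```python
import paddle
import paddle.nn as nn

class ToyGenerator(nn.Layer):
    """Hypothetical stand-in for nets['netG']; not repo code."""
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 8)

    def forward(self, x):
        return self.fc(x)

    def random_inputs(self, batch_size):
        # custom helper, not part of nn.Layer's interface
        return paddle.randn([batch_size, 8])

net = ToyGenerator()
# In distributed training the model gets wrapped:
#     net = paddle.DataParallel(net)
# after which net.random_inputs(4) raises AttributeError, while
# net._layers.random_inputs(4) still works. Hence the guard:
inner = net._layers if isinstance(net, paddle.DataParallel) else net
z = inner.random_inputs(4)
```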
ppgan/models/starganv2_model.py @ fccd4d47
```diff
@@ -25,20 +25,29 @@ def translate_using_reference(nets, w_hpf, x_src, x_ref, y_ref):
     for _ in range(N):
         s_ref_lists.append(s_ref_list)
     s_ref_list = paddle.stack(s_ref_lists, axis=1)
-    s_ref_list = paddle.reshape(s_ref_list, (s_ref_list.shape[0], s_ref_list.shape[1], s_ref_list.shape[3]))
+    s_ref_list = paddle.reshape(
+        s_ref_list,
+        (s_ref_list.shape[0], s_ref_list.shape[1], s_ref_list.shape[3]))
     x_concat = [x_src_with_wb]
     for i, s_ref in enumerate(s_ref_list):
         x_fake = nets['generator'](x_src, s_ref, masks=masks)
-        x_fake_with_ref = paddle.concat([x_ref[i:i+1], x_fake], axis=0)
+        x_fake_with_ref = paddle.concat([x_ref[i:i + 1], x_fake], axis=0)
         x_concat += [x_fake_with_ref]
 
     x_concat = paddle.concat(x_concat, axis=0)
-    img = tensor2img(make_grid(x_concat, nrow=N+1, range=(0, 1)))
+    img = tensor2img(make_grid(x_concat, nrow=N + 1, range=(0, 1)))
     del x_concat
     return img
 
 
-def compute_d_loss(nets, lambda_reg, x_real, y_org, y_trg, z_trg=None, x_ref=None, masks=None):
+def compute_d_loss(nets,
+                   lambda_reg,
+                   x_real,
+                   y_org,
+                   y_trg,
+                   z_trg=None,
+                   x_ref=None,
+                   masks=None):
     assert (z_trg is None) != (x_ref is None)
     # with real images
     x_real.stop_gradient = False
```
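The assert in compute_d_loss is an exclusive-or over keyword arguments: the discriminator loss is driven either by a latent code z_trg or by a reference image x_ref, never both and never neither. The idiom in isolation (hypothetical helper, not repo code):

```python
def style_source(z_trg=None, x_ref=None):
    # (a is None) != (b is None) is True exactly when one argument is given
    assert (z_trg is None) != (x_ref is None)
    return 'latent code' if z_trg is not None else 'reference image'

print(style_source(z_trg=[0.1]))      # latent code
print(style_source(x_ref=object()))   # reference image
```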
```diff
@@ -58,9 +67,11 @@ def compute_d_loss(nets, lambda_reg, x_real, y_org, y_trg, z_trg=None, x_ref=None, masks=None):
     loss_fake = adv_loss(out, 0)
 
     loss = loss_real + loss_fake + lambda_reg * loss_reg
-    return loss, {'real': loss_real.numpy(), 'fake': loss_fake.numpy(), 'reg': loss_reg.numpy()}
+    return loss, {
+        'real': loss_real.numpy(),
+        'fake': loss_fake.numpy(),
+        'reg': loss_reg.numpy()
+    }
 
 
 def adv_loss(logits, target):
```
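The body of adv_loss sits outside this hunk. For orientation, a sketch of the standard StarGAN v2 formulation (binary cross-entropy on raw logits against a constant 0/1 target); the file's actual implementation may differ in detail:

```python
import paddle
import paddle.nn.functional as F

def adv_loss_sketch(logits, target):
    # target is 1 for "real" and 0 for "fake"
    assert target in [1, 0]
    targets = paddle.full_like(logits, fill_value=float(target))
    return F.binary_cross_entropy_with_logits(logits, targets)
```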
```diff
@@ -73,21 +84,29 @@ def adv_loss(logits, target):
 
 def r1_reg(d_out, x_in):
     # zero-centered gradient penalty for real images
     batch_size = x_in.shape[0]
-    grad_dout = paddle.grad(outputs=d_out.sum(), inputs=x_in, create_graph=True, retain_graph=True, only_inputs=True)[0]
+    grad_dout = paddle.grad(outputs=d_out.sum(),
+                            inputs=x_in,
+                            create_graph=True,
+                            retain_graph=True,
+                            only_inputs=True)[0]
     grad_dout2 = grad_dout.pow(2)
-    assert(grad_dout2.shape == x_in.shape)
+    assert (grad_dout2.shape == x_in.shape)
     reg = 0.5 * paddle.reshape(grad_dout2, (batch_size, -1)).sum(1).mean(0)
     return reg
 
 
 def soft_update(source, target, beta=1.0):
     assert 0.0 <= beta <= 1.0
+
+    if isinstance(source, paddle.DataParallel):
+        source = source._layers
+
     target_model_map = dict(target.named_parameters())
     for param_name, source_param in source.named_parameters():
         target_param = target_model_map[param_name]
-        target_param.set_value(beta * source_param + (1.0 - beta) * target_param)
+        target_param.set_value(beta * source_param +
+                               (1.0 - beta) * target_param)
 
 
 def dump_model(model):
     params = {}
```
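r1_reg is the zero-centered R1 gradient penalty, 0.5 * E[||grad_x D(x)||^2] over the batch. A runnable toy of the same pattern; note that paddle.grad only differentiates with respect to tensors whose stop_gradient is False, which is why compute_d_loss sets x_real.stop_gradient = False first:

```python
import paddle

x = paddle.randn([4, 3, 8, 8])
x.stop_gradient = False          # make x a differentiable leaf
d_out = (x * 2.0).sum()          # stand-in for discriminator(x).sum()
grad = paddle.grad(outputs=d_out,
                   inputs=x,
                   create_graph=True,
                   retain_graph=True,
                   only_inputs=True)[0]
# gradient is 2 everywhere, so per-sample 0.5 * sum(2^2) = 0.5 * 4 * 3*8*8
reg = 0.5 * paddle.reshape(grad.pow(2), (4, -1)).sum(1).mean(0)
print(float(reg))                # 384.0
```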
```diff
@@ -97,7 +116,17 @@ def dump_model(model):
     return params
 
 
-def compute_g_loss(nets, w_hpf, lambda_sty, lambda_ds, lambda_cyc, x_real, y_org, y_trg, z_trgs=None, x_refs=None, masks=None):
+def compute_g_loss(nets,
+                   w_hpf,
+                   lambda_sty,
+                   lambda_ds,
+                   lambda_cyc,
+                   x_real,
+                   y_org,
+                   y_trg,
+                   z_trgs=None,
+                   x_refs=None,
+                   masks=None):
     assert (z_trgs is None) != (x_refs is None)
     if z_trgs is not None:
         z_trg, z_trg2 = z_trgs
```
```diff
@@ -127,17 +156,23 @@ def compute_g_loss(nets, w_hpf, lambda_sty, lambda_ds, lambda_cyc, x_real, y_org, y_trg, z_trgs=None, x_refs=None, masks=None):
     loss_ds = paddle.mean(paddle.abs(x_fake - x_fake2))
 
     # cycle-consistency loss
-    masks = nets['fan'].get_heatmap(x_fake) if w_hpf > 0 else None
+    if isinstance(nets['fan'], paddle.DataParallel):
+        masks = nets['fan']._layers.get_heatmap(x_fake) if w_hpf > 0 else None
+    else:
+        masks = nets['fan'].get_heatmap(x_fake) if w_hpf > 0 else None
     s_org = nets['style_encoder'](x_real, y_org)
     x_rec = nets['generator'](x_fake, s_org, masks=masks)
     loss_cyc = paddle.mean(paddle.abs(x_rec - x_real))
 
     loss = loss_adv + lambda_sty * loss_sty \
         - lambda_ds * loss_ds + lambda_cyc * loss_cyc
-    return loss, {'adv': loss_adv.numpy(), 'sty': loss_sty.numpy(), 'ds:': loss_ds.numpy(), 'cyc': loss_cyc.numpy()}
+    return loss, {
+        'adv': loss_adv.numpy(),
+        'sty': loss_sty.numpy(),
+        'ds:': loss_ds.numpy(),
+        'cyc': loss_cyc.numpy()
+    }
 
 
 def he_init(module):
```
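With this hunk the isinstance(..., paddle.DataParallel) guard now appears in gan_model.py, in compute_g_loss, and again in train_iter below. A possible follow-up refactoring (a sketch, not part of this commit) would centralize the unwrap:

```python
import paddle

def unwrap(layer):
    """Return the underlying Layer whether or not it is DataParallel-wrapped."""
    return layer._layers if isinstance(layer, paddle.DataParallel) else layer

# The cycle-consistency branch above would then collapse to:
#     masks = unwrap(nets['fan']).get_heatmap(x_fake) if w_hpf > 0 else None
```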
```diff
@@ -154,7 +189,7 @@ def he_init(module):
 
 @MODELS.register()
 class StarGANv2Model(BaseModel):
-    def __init__(self,
+    def __init__(self,
                  generator,
                  style=None,
                  mapping=None,
```
```diff
@@ -195,7 +230,7 @@ class StarGANv2Model(BaseModel):
         # remember the initial value of ds weight
         self.initial_lambda_ds = self.lambda_ds
 
-    def setup_input(self, input):
+    def setup_input(self, input):
         """Unpack input data from the dataloader and perform necessary pre-processing steps.
```
```diff
@@ -206,8 +241,10 @@ class StarGANv2Model(BaseModel):
         """
         pass
         self.input = input
-        self.input['z_trg'] = paddle.randn((input['src'].shape[0], self.latent_dim))
-        self.input['z_trg2'] = paddle.randn((input['src'].shape[0], self.latent_dim))
+        self.input['z_trg'] = paddle.randn(
+            (input['src'].shape[0], self.latent_dim))
+        self.input['z_trg2'] = paddle.randn(
+            (input['src'].shape[0], self.latent_dim))
 
     def forward(self):
         """Run forward pass; called by both functions <optimize_parameters> and <test>."""
```
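setup_input draws two independent latent codes per sample; compute_g_loss decodes both and pushes their outputs apart through the diversity-sensitive term mean|x_fake - x_fake2|. The sampling in isolation (illustrative sizes):

```python
import paddle

batch_size, latent_dim = 4, 16                     # hypothetical values
z_trg = paddle.randn((batch_size, latent_dim))     # first style code
z_trg2 = paddle.randn((batch_size, latent_dim))    # independent second draw
```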
```diff
@@ -220,50 +257,89 @@ class StarGANv2Model(BaseModel):
 
     def train_iter(self, optimizers=None):
         #TODO
         x_real, y_org = self.input['src'], self.input['src_cls']
-        x_ref, x_ref2, y_trg = self.input['ref'], self.input['ref2'], self.input['ref_cls']
+        x_ref, x_ref2, y_trg = self.input['ref'], self.input[
+            'ref2'], self.input['ref_cls']
         z_trg, z_trg2 = self.input['z_trg'], self.input['z_trg2']
 
-        masks = self.nets['fan'].get_heatmap(x_real) if self.w_hpf > 0 else None
+        if isinstance(self.nets['fan'], paddle.DataParallel):
+            masks = self.nets['fan']._layers.get_heatmap(
+                x_real) if self.w_hpf > 0 else None
+        else:
+            masks = self.nets['fan'].get_heatmap(
+                x_real) if self.w_hpf > 0 else None
 
         # train the discriminator
-        d_loss, d_losses_latent = compute_d_loss(self.nets, self.lambda_reg, x_real, y_org, y_trg, z_trg=z_trg, masks=masks)
+        d_loss, d_losses_latent = compute_d_loss(self.nets,
+                                                 self.lambda_reg,
+                                                 x_real,
+                                                 y_org,
+                                                 y_trg,
+                                                 z_trg=z_trg,
+                                                 masks=masks)
         self._reset_grad(optimizers)
         d_loss.backward()
         optimizers['discriminator'].minimize(d_loss)
 
-        d_loss, d_losses_ref = compute_d_loss(self.nets, self.lambda_reg, x_real, y_org, y_trg, x_ref=x_ref, masks=masks)
+        d_loss, d_losses_ref = compute_d_loss(self.nets,
+                                              self.lambda_reg,
+                                              x_real,
+                                              y_org,
+                                              y_trg,
+                                              x_ref=x_ref,
+                                              masks=masks)
         self._reset_grad(optimizers)
         d_loss.backward()
         optimizers['discriminator'].step()
 
         # train the generator
-        g_loss, g_losses_latent = compute_g_loss(self.nets, self.w_hpf, self.lambda_sty, self.lambda_ds, self.lambda_cyc, x_real, y_org, y_trg, z_trgs=[z_trg, z_trg2], masks=masks)
+        g_loss, g_losses_latent = compute_g_loss(self.nets,
+                                                 self.w_hpf,
+                                                 self.lambda_sty,
+                                                 self.lambda_ds,
+                                                 self.lambda_cyc,
+                                                 x_real,
+                                                 y_org,
+                                                 y_trg,
+                                                 z_trgs=[z_trg, z_trg2],
+                                                 masks=masks)
         self._reset_grad(optimizers)
         g_loss.backward()
         optimizers['generator'].step()
         optimizers['mapping_network'].step()
         optimizers['style_encoder'].step()
 
-        g_loss, g_losses_ref = compute_g_loss(self.nets, self.w_hpf, self.lambda_sty, self.lambda_ds, self.lambda_cyc, x_real, y_org, y_trg, x_refs=[x_ref, x_ref2], masks=masks)
+        g_loss, g_losses_ref = compute_g_loss(self.nets,
+                                              self.w_hpf,
+                                              self.lambda_sty,
+                                              self.lambda_ds,
+                                              self.lambda_cyc,
+                                              x_real,
+                                              y_org,
+                                              y_trg,
+                                              x_refs=[x_ref, x_ref2],
+                                              masks=masks)
         self._reset_grad(optimizers)
         g_loss.backward()
         optimizers['generator'].step()
 
         # compute moving average of network parameters
-        soft_update(self.nets['generator'], self.nets_ema['generator'], beta=0.999)
-        soft_update(self.nets['mapping_network'], self.nets_ema['mapping_network'], beta=0.999)
-        soft_update(self.nets['style_encoder'], self.nets_ema['style_encoder'], beta=0.999)
+        soft_update(self.nets['generator'],
+                    self.nets_ema['generator'],
+                    beta=0.999)
+        soft_update(self.nets['mapping_network'],
+                    self.nets_ema['mapping_network'],
+                    beta=0.999)
+        soft_update(self.nets['style_encoder'],
+                    self.nets_ema['style_encoder'],
+                    beta=0.999)
 
         # decay weight for diversity sensitive loss
         if self.lambda_ds > 0:
             self.lambda_ds -= (self.initial_lambda_ds / self.total_iter)
 
-        for loss, prefix in zip([d_losses_latent, d_losses_ref, g_losses_latent, g_losses_ref], ['D/latent_', 'D/ref_', 'G/latent_', 'G/ref_']):
+        for loss, prefix in zip(
+            [d_losses_latent, d_losses_ref, g_losses_latent, g_losses_ref],
+            ['D/latent_', 'D/ref_', 'G/latent_', 'G/ref_']):
             for key, value in loss.items():
                 self.losses[prefix + key] = value
         self.losses['G/lambda_ds'] = self.lambda_ds
```
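Note the direction of soft_update's blend: target = beta * source + (1 - beta) * target with beta = 0.999, so the EMA copy keeps only 0.1% of its previous value per iteration and tracks the online network very closely. The update rule on scalars (illustrative values only):

```python
beta = 0.999
source, target = 1.0, 0.0
for step in range(3):
    target = beta * source + (1.0 - beta) * target
    print(step, target)   # jumps to 0.999 after the first update
```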
```diff
@@ -273,17 +349,24 @@ class StarGANv2Model(BaseModel):
         #TODO
         self.nets_ema['generator'].eval()
         self.nets_ema['style_encoder'].eval()
-        soft_update(self.nets['generator'], self.nets_ema['generator'], beta=0.999)
-        soft_update(self.nets['mapping_network'], self.nets_ema['mapping_network'], beta=0.999)
-        soft_update(self.nets['style_encoder'], self.nets_ema['style_encoder'], beta=0.999)
+        soft_update(self.nets['generator'],
+                    self.nets_ema['generator'],
+                    beta=0.999)
+        soft_update(self.nets['mapping_network'],
+                    self.nets_ema['mapping_network'],
+                    beta=0.999)
+        soft_update(self.nets['style_encoder'],
+                    self.nets_ema['style_encoder'],
+                    beta=0.999)
         src_img = self.input['src']
         ref_img = self.input['ref']
         ref_label = self.input['ref_cls']
         with paddle.no_grad():
-            img = translate_using_reference(self.nets_ema, self.w_hpf, paddle.to_tensor(src_img).astype('float32'), paddle.to_tensor(ref_img).astype('float32'), paddle.to_tensor(ref_label).astype('float32'))
+            img = translate_using_reference(
+                self.nets_ema, self.w_hpf,
+                paddle.to_tensor(src_img).astype('float32'),
+                paddle.to_tensor(ref_img).astype('float32'),
+                paddle.to_tensor(ref_label).astype('float32'))
             self.visual_items['reference'] = img
         self.nets_ema['generator'].train()
         self.nets_ema['style_encoder'].train()
```
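This method brackets the visualization with eval()/train() and runs it under paddle.no_grad(), so no autograd graph is built for the EMA networks. The pattern in miniature (a toy layer stands in for the EMA generator):

```python
import paddle
import paddle.nn as nn

net = nn.Linear(4, 4)    # stand-in for nets_ema['generator']
net.eval()               # inference behavior for dropout/batch norm
with paddle.no_grad():   # no gradient tracking during visualization
    out = net(paddle.randn([2, 4]))
net.train()              # restore training mode afterwards
```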