Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleGAN
提交
cf153154
P
PaddleGAN
项目概览
PaddlePaddle
/
PaddleGAN
大约 1 年 前同步成功
通知
97
Star
7254
Fork
1210
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
4
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleGAN
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
4
Issue
4
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
cf153154
编写于
2月 06, 2021
作者:
L
LielinJiang
提交者:
GitHub
2月 06, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix animegan (#158)
* adapt animegan * clean code Co-authored-by:
N
qingqing01
<
dangqingqing@baidu.com
>
上级
6cee870f
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
100 addition
and
69 deletion
+100
-69
configs/animeganv2.yaml
configs/animeganv2.yaml
+26
-16
configs/animeganv2_pretrain.yaml
configs/animeganv2_pretrain.yaml
+20
-10
ppgan/datasets/animeganv2_dataset.py
ppgan/datasets/animeganv2_dataset.py
+15
-8
ppgan/models/animeganv2_model.py
ppgan/models/animeganv2_model.py
+39
-35
未找到文件。
configs/animeganv2.yaml
浏览文件 @
cf153154
epochs
:
30
output_dir
:
output_dir
pretrain_ckpt
:
output_dir/AnimeGANV2PreTrainModel-2020-11-29-17-02/epoch_2_checkpoint.pdparams
g_adv_weight
:
300.
d_adv_weight
:
300.
con_weight
:
1.5
sty_weight
:
2.5
color_weight
:
10.
tv_weight
:
1.
model
:
name
:
AnimeGANV2Model
...
...
@@ -14,7 +7,16 @@ model:
name
:
AnimeGenerator
discriminator
:
name
:
AnimeDiscriminator
gan_criterion
:
name
:
GANLoss
gan_mode
:
lsgan
pretrain_ckpt
:
output_dir/AnimeGANV2PreTrainModel-2020-11-29-17-02/epoch_2_checkpoint.pdparams
g_adv_weight
:
300.
d_adv_weight
:
300.
con_weight
:
1.5
sty_weight
:
2.5
color_weight
:
10.
tv_weight
:
1.
dataset
:
train
:
...
...
@@ -23,8 +25,6 @@ dataset:
batch_size
:
4
dataroot
:
data/animedataset
style
:
Hayao
phase
:
train
direction
:
AtoB
transform_real
:
-
name
:
Transpose
-
name
:
Normalize
...
...
@@ -63,15 +63,25 @@ dataset:
mean
:
[
127.5
,
127.5
,
127.5
]
std
:
[
127.5
,
127.5
,
127.5
]
optimizer
:
name
:
Adam
beta1
:
0.5
lr_scheduler
:
name
:
linear
learning_rate
:
0.000
0
2
name
:
LinearDecay
learning_rate
:
0.0002
start_epoch
:
100
decay_epochs
:
100
# will get from real dataset
iters_per_epoch
:
1
optimizer
:
optimizer_G
:
name
:
Adam
net_names
:
-
netG
beta1
:
0.5
optimizer_D
:
name
:
Adam
net_names
:
-
netD
beta1
:
0.5
log_config
:
interval
:
100
...
...
configs/animeganv2_pretrain.yaml
浏览文件 @
cf153154
epochs
:
2
output_dir
:
output_dir
con_weight
:
1
pretrain_ckpt
:
null
model
:
name
:
AnimeGANV2PreTrainModel
...
...
@@ -9,7 +7,11 @@ model:
name
:
AnimeGenerator
discriminator
:
name
:
AnimeDiscriminator
gan_criterion
:
name
:
GANLoss
gan_mode
:
lsgan
con_weight
:
1
pretrain_ckpt
:
null
dataset
:
train
:
...
...
@@ -18,8 +20,6 @@ dataset:
batch_size
:
4
dataroot
:
data/animedataset
style
:
Hayao
phase
:
train
direction
:
AtoB
transform_real
:
-
name
:
Transpose
-
name
:
Normalize
...
...
@@ -57,15 +57,25 @@ dataset:
mean
:
[
127.5
,
127.5
,
127.5
]
std
:
[
127.5
,
127.5
,
127.5
]
optimizer
:
name
:
Adam
beta1
:
0.5
lr_scheduler
:
name
:
linear
name
:
LinearDecay
learning_rate
:
0.0002
start_epoch
:
100
decay_epochs
:
100
# will get from real dataset
iters_per_epoch
:
1
optimizer
:
optimizer_G
:
name
:
Adam
net_names
:
-
netG
beta1
:
0.5
optimizer_D
:
name
:
Adam
net_names
:
-
netD
beta1
:
0.5
log_config
:
interval
:
100
...
...
ppgan/datasets/animeganv2_dataset.py
浏览文件 @
cf153154
...
...
@@ -13,8 +13,9 @@
#limitations under the License.
import
cv2
import
numpy
as
np
import
os.path
import
numpy
as
np
import
paddle
from
.base_dataset
import
BaseDataset
from
.image_folder
import
ImageFolder
...
...
@@ -23,21 +24,27 @@ from .transforms.builder import build_transforms
@
DATASETS
.
register
()
class
AnimeGANV2Dataset
(
Base
Dataset
):
class
AnimeGANV2Dataset
(
paddle
.
io
.
Dataset
):
"""
"""
def
__init__
(
self
,
cfg
):
def
__init__
(
self
,
dataroot
,
style
,
transform_real
=
None
,
transform_anime
=
None
,
transform_gray
=
None
):
"""Initialize this dataset class.
Args:
cfg (dict) -- stores all the experiment flags
"""
BaseDataset
.
__init__
(
self
,
cfg
)
self
.
style
=
cfg
.
style
# self.cfg = cfg
self
.
root
=
dataroot
self
.
style
=
style
self
.
transform_real
=
build_transforms
(
self
.
cfg
.
transform_real
)
self
.
transform_anime
=
build_transforms
(
self
.
cfg
.
transform_anime
)
self
.
transform_gray
=
build_transforms
(
self
.
cfg
.
transform_gray
)
self
.
transform_real
=
build_transforms
(
transform_real
)
self
.
transform_anime
=
build_transforms
(
transform_anime
)
self
.
transform_gray
=
build_transforms
(
transform_gray
)
self
.
real_root
=
os
.
path
.
join
(
self
.
root
,
'train_photo'
)
self
.
anime_root
=
os
.
path
.
join
(
self
.
root
,
f
'
{
self
.
style
}
'
,
'style'
)
...
...
ppgan/models/animeganv2_model.py
浏览文件 @
cf153154
...
...
@@ -13,62 +13,66 @@
#limitations under the License.
import
paddle
from
paddle
import
nn
import
paddle.nn
as
nn
from
.base_model
import
BaseModel
from
.builder
import
MODELS
from
.generators.builder
import
build_generator
from
.discriminators.builder
import
build_discriminator
from
.criterions
.gan_loss
import
GANLoss
from
.criterions
import
build_criterion
from
..modules.caffevgg
import
CaffeVGG19
from
..solver
import
build_optimizer
from
..modules.init
import
init_weights
from
..utils.filesystem
import
load
@
MODELS
.
register
()
class
AnimeGANV2Model
(
BaseModel
):
def
__init__
(
self
,
cfg
):
def
__init__
(
self
,
generator
,
discriminator
=
None
,
gan_criterion
=
None
,
pretrain_ckpt
=
None
,
g_adv_weight
=
300.
,
d_adv_weight
=
300.
,
con_weight
=
1.5
,
sty_weight
=
2.5
,
color_weight
=
10.
,
tv_weight
=
1.
):
"""Initialize the AnimeGANV2 class.
Parameters:
opt (config dict)-- stores all the experiment flags; needs to be a subclass of Dict
"""
super
(
AnimeGANV2Model
,
self
).
__init__
(
cfg
)
super
(
AnimeGANV2Model
,
self
).
__init__
()
self
.
g_adv_weight
=
g_adv_weight
self
.
d_adv_weight
=
d_adv_weight
self
.
con_weight
=
con_weight
self
.
sty_weight
=
sty_weight
self
.
color_weight
=
color_weight
self
.
tv_weight
=
tv_weight
# define networks (both generator and discriminator)
self
.
nets
[
'netG'
]
=
build_generator
(
cfg
.
model
.
generator
)
self
.
nets
[
'netG'
]
=
build_generator
(
generator
)
init_weights
(
self
.
nets
[
'netG'
])
# define a discriminator; conditional GANs need to take both input and output images; Therefore, #channels for D is input_nc + output_nc
if
self
.
is_train
:
self
.
nets
[
'netD'
]
=
build_discriminator
(
cfg
.
model
.
discriminator
)
self
.
nets
[
'netD'
]
=
build_discriminator
(
discriminator
)
init_weights
(
self
.
nets
[
'netD'
])
self
.
pretrained
=
CaffeVGG19
()
self
.
losses
=
{}
# define loss functions
self
.
criterionGAN
=
GANLoss
(
cfg
.
model
.
gan_mode
)
self
.
criterionGAN
=
build_criterion
(
gan_criterion
)
self
.
criterionL1
=
nn
.
L1Loss
()
self
.
criterionHub
=
nn
.
SmoothL1Loss
()
# build optimizers
self
.
build_lr_scheduler
()
self
.
optimizers
[
'optimizer_G'
]
=
build_optimizer
(
cfg
.
optimizer
,
self
.
lr_scheduler
,
parameter_list
=
self
.
nets
[
'netG'
].
parameters
())
self
.
optimizers
[
'optimizer_D'
]
=
build_optimizer
(
cfg
.
optimizer
,
self
.
lr_scheduler
,
parameter_list
=
self
.
nets
[
'netD'
].
parameters
())
if
self
.
cfg
.
pretrain_ckpt
:
state_dicts
=
load
(
self
.
cfg
.
pretrain_ckpt
)
if
pretrain_ckpt
:
state_dicts
=
load
(
pretrain_ckpt
)
self
.
nets
[
'netG'
].
set_state_dict
(
state_dicts
[
'netG'
])
print
(
'Load pretrained generator from'
,
self
.
cfg
.
pretrain_ckpt
)
print
(
'Load pretrained generator from'
,
pretrain_ckpt
)
def
set_input
(
self
,
input
):
def
set
up
_input
(
self
,
input
):
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
"""
...
...
@@ -152,13 +156,13 @@ class AnimeGANV2Model(BaseModel):
fake_logit
=
self
.
nets
[
'netD'
](
self
.
fake
.
detach
())
smooth_logit
=
self
.
nets
[
'netD'
](
self
.
smooth_gray
)
d_real_loss
=
(
self
.
cfg
.
d_adv_weight
*
1.2
*
d_real_loss
=
(
self
.
d_adv_weight
*
1.2
*
self
.
criterionGAN
(
real_logit
,
True
))
d_gray_loss
=
(
self
.
cfg
.
d_adv_weight
*
1.2
*
d_gray_loss
=
(
self
.
d_adv_weight
*
1.2
*
self
.
criterionGAN
(
gray_logit
,
False
))
d_fake_loss
=
(
self
.
cfg
.
d_adv_weight
*
1.2
*
d_fake_loss
=
(
self
.
d_adv_weight
*
1.2
*
self
.
criterionGAN
(
fake_logit
,
False
))
d_blur_loss
=
(
self
.
cfg
.
d_adv_weight
*
0.8
*
d_blur_loss
=
(
self
.
d_adv_weight
*
0.8
*
self
.
criterionGAN
(
smooth_logit
,
False
))
self
.
loss_D
=
d_real_loss
+
d_gray_loss
+
d_fake_loss
+
d_blur_loss
...
...
@@ -175,11 +179,11 @@ class AnimeGANV2Model(BaseModel):
fake_logit
=
self
.
nets
[
'netD'
](
self
.
fake
)
c_loss
,
s_loss
=
self
.
con_sty_loss
(
self
.
real
,
self
.
anime_gray
,
self
.
fake
)
c_loss
=
self
.
c
fg
.
c
on_weight
*
c_loss
s_loss
=
self
.
cfg
.
sty_weight
*
s_loss
tv_loss
=
self
.
cfg
.
tv_weight
*
self
.
variation_loss
(
self
.
fake
)
col_loss
=
self
.
c
fg
.
c
olor_weight
*
self
.
color_loss
(
self
.
real
,
self
.
fake
)
g_loss
=
(
self
.
cfg
.
g_adv_weight
*
self
.
criterionGAN
(
fake_logit
,
True
))
c_loss
=
self
.
con_weight
*
c_loss
s_loss
=
self
.
sty_weight
*
s_loss
tv_loss
=
self
.
tv_weight
*
self
.
variation_loss
(
self
.
fake
)
col_loss
=
self
.
color_weight
*
self
.
color_loss
(
self
.
real
,
self
.
fake
)
g_loss
=
(
self
.
g_adv_weight
*
self
.
criterionGAN
(
fake_logit
,
True
))
self
.
loss_G
=
c_loss
+
s_loss
+
col_loss
+
g_loss
+
tv_loss
...
...
@@ -191,7 +195,7 @@ class AnimeGANV2Model(BaseModel):
self
.
losses
[
'col_loss'
]
=
col_loss
self
.
losses
[
'tv_loss'
]
=
tv_loss
def
optimize_parameters
(
self
):
def
train_iter
(
self
,
optimizers
=
None
):
# compute fake images: G(A)
self
.
forward
()
...
...
@@ -212,11 +216,11 @@ class AnimeGANV2PreTrainModel(AnimeGANV2Model):
real_feature_map
=
self
.
pretrained
(
self
.
real
)
fake_feature_map
=
self
.
pretrained
(
self
.
fake
)
init_c_loss
=
self
.
criterionL1
(
real_feature_map
,
fake_feature_map
)
loss
=
self
.
c
fg
.
c
on_weight
*
init_c_loss
loss
=
self
.
con_weight
*
init_c_loss
loss
.
backward
()
self
.
losses
[
'init_c_loss'
]
=
init_c_loss
def
optimize_parameters
(
self
):
def
train_iter
(
self
,
optimizers
=
None
):
self
.
forward
()
# update G
self
.
optimizers
[
'optimizer_G'
].
clear_grad
()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录