Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleGAN
提交
f7b53f07
P
PaddleGAN
项目概览
PaddlePaddle
/
PaddleGAN
大约 1 年 前同步成功
通知
97
Star
7254
Fork
1210
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
4
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleGAN
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
4
Issue
4
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
f7b53f07
编写于
12月 17, 2020
作者:
L
LielinJiang
提交者:
GitHub
12月 17, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
adapt wgan (#128)
上级
7bba9f8d
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
114 addition
and
73 deletion
+114
-73
configs/wgan_mnist.yaml
configs/wgan_mnist.yaml
+25
-15
ppgan/datasets/common_vision_dataset.py
ppgan/datasets/common_vision_dataset.py
+21
-12
ppgan/engine/trainer.py
ppgan/engine/trainer.py
+16
-4
ppgan/models/base_model.py
ppgan/models/base_model.py
+10
-2
ppgan/models/gan_model.py
ppgan/models/gan_model.py
+42
-40
未找到文件。
configs/wgan_mnist.yaml
浏览文件 @
f7b53f07
...
...
@@ -15,18 +15,20 @@ model:
n_layers
:
3
input_nc
:
1
norm_type
:
instance
gan_mode
:
wgan
n_critic
:
5
gan_criterion
:
name
:
GANLoss
gan_mode
:
wgan
params
:
disc_iters
:
5
visual_interval
:
500
dataset
:
train
:
name
:
CommonVisionDataset
class_name
:
MNIST
dataroot
:
None
dataset_name
:
MNIST
num_workers
:
4
batch_size
:
64
mode
:
train
return_cls
:
False
return_label
:
False
transforms
:
-
name
:
Normalize
mean
:
[
127.5
]
...
...
@@ -34,28 +36,36 @@ dataset:
keys
:
[
image
]
test
:
name
:
CommonVisionDataset
class_name
:
MNIST
dataroot
:
None
dataset_name
:
MNIST
num_workers
:
0
batch_size
:
64
mode
:
test
return_label
:
False
transforms
:
-
name
:
Normalize
mean
:
[
127.5
]
std
:
[
127.5
]
keys
:
[
image
]
return_cls
:
False
optimizer
:
name
:
Adam
beta1
:
0.5
lr_scheduler
:
name
:
linear
name
:
LinearDecay
learning_rate
:
0.0002
start_epoch
:
100
decay_epochs
:
100
# will get from real dataset
iters_per_epoch
:
1
optimizer
:
optimizer_G
:
name
:
Adam
net_names
:
-
netG
beta1
:
0.5
optimizer_D
:
name
:
Adam
net_names
:
-
netD
beta1
:
0.5
log_config
:
interval
:
100
...
...
ppgan/datasets/common_vision_dataset.py
浏览文件 @
f7b53f07
...
...
@@ -21,29 +21,38 @@ from .transforms.builder import build_transforms
@
DATASETS
.
register
()
class
CommonVisionDataset
(
Base
Dataset
):
class
CommonVisionDataset
(
paddle
.
io
.
Dataset
):
"""
Dataset for using paddle vision default datasets
Dataset for using paddle vision default datasets
, such as mnist, flowers.
"""
def
__init__
(
self
,
cfg
):
def
__init__
(
self
,
dataset_name
,
transforms
=
None
,
return_label
=
True
,
params
=
None
):
"""Initialize this dataset class.
Args:
cfg (dict) -- stores all the experiment flags
dataset_name (str): return a dataset from paddle.vision.datasets by this option.
transforms (list[dict]): A sequence of data transforms config.
return_label (bool): whether to retuan a label of a sample.
params (dict): paramters of paddle.vision.datasets.
"""
super
(
CommonVisionDataset
,
self
).
__init__
(
cfg
)
super
(
CommonVisionDataset
,
self
).
__init__
()
dataset_cls
=
getattr
(
paddle
.
vision
.
datasets
,
cfg
.
pop
(
'class_name'
)
)
transform
=
build_transforms
(
cfg
.
pop
(
'transforms'
,
None
)
)
self
.
return_
cls
=
cfg
.
pop
(
'return_cls'
,
True
)
dataset_cls
=
getattr
(
paddle
.
vision
.
datasets
,
dataset_name
)
transform
=
build_transforms
(
transforms
)
self
.
return_
label
=
return_label
param_dict
=
{}
param_names
=
list
(
dataset_cls
.
__init__
.
__code__
.
co_varnames
)
if
'transform'
in
param_names
:
param_dict
[
'transform'
]
=
transform
for
name
in
param_names
:
if
name
in
cfg
:
param_dict
[
name
]
=
cfg
.
get
(
name
)
if
params
is
not
None
:
for
name
in
param_names
:
if
name
in
params
:
param_dict
[
name
]
=
params
[
name
]
self
.
dataset
=
dataset_cls
(
**
param_dict
)
...
...
@@ -53,7 +62,7 @@ class CommonVisionDataset(BaseDataset):
if
isinstance
(
return_list
,
(
tuple
,
list
)):
if
len
(
return_list
)
==
2
:
return_dict
[
'img'
]
=
return_list
[
0
]
if
self
.
return_
cls
:
if
self
.
return_
label
:
return_dict
[
'class_id'
]
=
np
.
asarray
(
return_list
[
1
])
else
:
return_dict
[
'img'
]
=
return_list
[
0
]
...
...
ppgan/engine/trainer.py
浏览文件 @
f7b53f07
...
...
@@ -211,12 +211,24 @@ class Trainer:
current_paths
=
self
.
model
.
get_image_paths
()
current_visuals
=
self
.
model
.
get_current_visuals
()
for
j
in
range
(
len
(
current_paths
)):
short_path
=
os
.
path
.
basename
(
current_paths
[
j
])
basename
=
os
.
path
.
splitext
(
short_path
)[
0
]
if
len
(
current_visuals
)
>
0
and
list
(
current_visuals
.
values
())[
0
].
shape
==
4
:
num_samples
=
list
(
current_visuals
.
values
())[
0
].
shape
[
0
]
else
:
num_samples
=
1
for
j
in
range
(
num_samples
):
if
j
<
len
(
current_paths
):
short_path
=
os
.
path
.
basename
(
current_paths
[
j
])
basename
=
os
.
path
.
splitext
(
short_path
)[
0
]
else
:
basename
=
'{:04d}_{:04d}'
.
format
(
i
,
j
)
for
k
,
img_tensor
in
current_visuals
.
items
():
name
=
'%s_%s'
%
(
basename
,
k
)
visual_results
.
update
({
name
:
img_tensor
[
j
]})
if
len
(
img_tensor
.
shape
)
==
4
:
visual_results
.
update
({
name
:
img_tensor
[
j
]})
else
:
visual_results
.
update
({
name
:
img_tensor
})
self
.
visual
(
'visual_test'
,
visual_results
=
visual_results
,
...
...
ppgan/models/base_model.py
浏览文件 @
f7b53f07
...
...
@@ -50,7 +50,7 @@ class BaseModel(ABC):
# save checkpoint (model.nets) \/
"""
def
__init__
(
self
):
def
__init__
(
self
,
params
=
None
):
"""Initialize the BaseModel class.
When creating your custom class, you need to implement your own initialization.
...
...
@@ -62,7 +62,13 @@ class BaseModel(ABC):
-- self.optimizers (dict): define and initialize optimizers. You can define one optimizer for each network.
If two networks are updated at the same time, you can use itertools.chain to group them.
See cycle_gan_model.py for an example.
Args:
params (dict): Hyper params for train or test. Default: None.
"""
self
.
params
=
params
self
.
is_train
=
True
if
self
.
params
is
None
else
self
.
params
.
get
(
'is_train'
,
True
)
self
.
nets
=
OrderedDict
()
self
.
optimizers
=
OrderedDict
()
...
...
@@ -149,7 +155,9 @@ class BaseModel(ABC):
def
get_image_paths
(
self
):
""" Return image paths that are used to load current data"""
return
self
.
image_paths
if
hasattr
(
self
,
'image_paths'
):
return
self
.
image_paths
return
[]
def
get_current_visuals
(
self
):
"""Return visualization images."""
...
...
ppgan/models/gan_model.py
浏览文件 @
f7b53f07
...
...
@@ -19,7 +19,7 @@ from .base_model import BaseModel
from
.builder
import
MODELS
from
.generators.builder
import
build_generator
from
.discriminators.builder
import
build_discriminator
from
.criterions.
gan_loss
import
GANLoss
from
.criterions.
builder
import
build_criterion
from
..solver
import
build_optimizer
from
..modules.init
import
init_weights
...
...
@@ -32,44 +32,46 @@ class GANModel(BaseModel):
vanilla GAN paper: https://arxiv.org/abs/1406.2661
"""
def
__init__
(
self
,
cfg
):
def
__init__
(
self
,
generator
,
discriminator
=
None
,
gan_criterion
=
None
,
params
=
None
):
"""Initialize the GAN Model class.
Parameters:
cfg (config dict)-- stores all the experiment flags; needs to be a subclass of Dict
Args:
generator (dict): config of generator.
discriminator (dict): config of discriminator.
gan_criterion (dict): config of gan criterion.
params (dict): hyper params for train or test. Default: None.
"""
super
(
GANModel
,
self
).
__init__
(
cfg
)
self
.
step
=
0
self
.
n_critic
=
cfg
.
model
.
get
(
'n_critic'
,
1
)
self
.
visual_interval
=
cfg
.
log_config
.
visiual_interval
self
.
samples_every_row
=
cfg
.
model
.
get
(
'samples_every_row'
,
8
)
# define networks (both generator and discriminator)
self
.
nets
[
'netG'
]
=
build_generator
(
cfg
.
model
.
generator
)
super
(
GANModel
,
self
).
__init__
(
params
)
self
.
iter
=
0
self
.
disc_iters
=
1
if
self
.
params
is
None
else
self
.
params
.
get
(
'disc_iters'
,
1
)
self
.
disc_start_iters
=
(
0
if
self
.
params
is
None
else
self
.
params
.
get
(
'disc_start_iters'
,
0
))
self
.
samples_every_row
=
(
8
if
self
.
params
is
None
else
self
.
params
.
get
(
'samples_every_row'
,
8
))
self
.
visual_interval
=
(
500
if
self
.
params
is
None
else
self
.
params
.
get
(
'visual_interval'
,
500
))
# define generator
self
.
nets
[
'netG'
]
=
build_generator
(
generator
)
init_weights
(
self
.
nets
[
'netG'
])
# define a discriminator
if
self
.
is_train
:
self
.
nets
[
'netD'
]
=
build_discriminator
(
cfg
.
model
.
discriminator
)
init_weights
(
self
.
nets
[
'netD'
])
if
discriminator
is
not
None
:
self
.
nets
[
'netD'
]
=
build_discriminator
(
discriminator
)
init_weights
(
self
.
nets
[
'netD'
])
if
self
.
is_train
:
self
.
losses
=
{}
# define loss functions
self
.
criterionGAN
=
GANLoss
(
cfg
.
model
.
gan_mode
)
# build optimizers
self
.
build_lr_scheduler
()
self
.
optimizers
[
'optimizer_G'
]
=
build_optimizer
(
cfg
.
optimizer
,
self
.
lr_scheduler
,
parameter_list
=
self
.
nets
[
'netG'
].
parameters
())
self
.
optimizers
[
'optimizer_D'
]
=
build_optimizer
(
cfg
.
optimizer
,
self
.
lr_scheduler
,
parameter_list
=
self
.
nets
[
'netD'
].
parameters
())
def
set_input
(
self
,
input
):
if
gan_criterion
:
self
.
criterionGAN
=
build_criterion
(
gan_criterion
)
def
setup_input
(
self
,
input
):
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
Parameters:
...
...
@@ -131,7 +133,7 @@ class GANModel(BaseModel):
self
.
loss_D_real
=
self
.
criterionGAN
(
pred_real
,
True
,
True
)
# combine loss and calculate gradients
if
self
.
c
fg
.
model
.
gan_mode
in
[
'vanilla'
,
'lsgan'
]:
if
self
.
c
riterionGAN
.
gan_mode
in
[
'vanilla'
,
'lsgan'
]:
self
.
loss_D
=
self
.
loss_D
+
(
self
.
loss_D_fake
+
self
.
loss_D_real
)
*
0.5
else
:
...
...
@@ -159,34 +161,34 @@ class GANModel(BaseModel):
self
.
losses
[
'G_adv_loss'
]
=
self
.
loss_G_GAN
def
optimize_parameters
(
self
):
def
train_iter
(
self
,
optimizers
=
None
):
# compute fake images: G(imgs)
self
.
forward
()
# update D
self
.
set_requires_grad
(
self
.
nets
[
'netD'
],
True
)
self
.
optimizers
[
'optimizer_D'
].
clear_grad
()
optimizers
[
'optimizer_D'
].
clear_grad
()
self
.
backward_D
()
self
.
optimizers
[
'optimizer_D'
].
step
()
optimizers
[
'optimizer_D'
].
step
()
self
.
set_requires_grad
(
self
.
nets
[
'netD'
],
False
)
# weight clip
if
self
.
c
fg
.
model
.
gan_mode
==
'wgan'
:
if
self
.
c
riterionGAN
.
gan_mode
==
'wgan'
:
with
paddle
.
no_grad
():
for
p
in
self
.
nets
[
'netD'
].
parameters
():
p
[:]
=
p
.
clip
(
-
0.01
,
0.01
)
if
self
.
step
%
self
.
n_critic
==
0
:
if
self
.
iter
>
self
.
disc_start_iters
and
self
.
iter
%
self
.
disc_iters
==
0
:
# update G
self
.
optimizers
[
'optimizer_G'
].
clear_grad
()
optimizers
[
'optimizer_G'
].
clear_grad
()
self
.
backward_G
()
self
.
optimizers
[
'optimizer_G'
].
step
()
optimizers
[
'optimizer_G'
].
step
()
if
self
.
step
%
self
.
visual_interval
==
0
:
if
self
.
iter
%
self
.
visual_interval
==
0
:
with
paddle
.
no_grad
():
self
.
visual_items
[
'fixed_generated_imgs'
]
=
make_grid
(
self
.
nets
[
'netG'
](
*
self
.
G_fixed_inputs
),
self
.
samples_every_row
)
self
.
step
+=
1
self
.
iter
+=
1
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录