Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
models
提交
25932361
M
models
项目概览
PaddlePaddle
/
models
大约 2 年 前同步成功
通知
232
Star
6828
Fork
2962
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
602
列表
看板
标记
里程碑
合并请求
255
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
models
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
602
Issue
602
列表
看板
标记
里程碑
合并请求
255
合并请求
255
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
25932361
编写于
9月 04, 2019
作者:
L
lvmengsi
提交者:
GitHub
9月 04, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix easy used (#3211)
* fix used
上级
50280922
变更
16
隐藏空白更改
内联
并排
Showing
16 changed file
with
109 addition
and
48 deletion
+109
-48
PaddleCV/PaddleGAN/network/base_network.py
PaddleCV/PaddleGAN/network/base_network.py
+2
-2
PaddleCV/PaddleGAN/scripts/infer_attgan.sh
PaddleCV/PaddleGAN/scripts/infer_attgan.sh
+1
-1
PaddleCV/PaddleGAN/scripts/infer_cgan.sh
PaddleCV/PaddleGAN/scripts/infer_cgan.sh
+1
-1
PaddleCV/PaddleGAN/scripts/infer_cyclegan.sh
PaddleCV/PaddleGAN/scripts/infer_cyclegan.sh
+1
-1
PaddleCV/PaddleGAN/scripts/infer_dcgan.sh
PaddleCV/PaddleGAN/scripts/infer_dcgan.sh
+1
-1
PaddleCV/PaddleGAN/scripts/infer_pix2pix.sh
PaddleCV/PaddleGAN/scripts/infer_pix2pix.sh
+1
-1
PaddleCV/PaddleGAN/scripts/infer_stargan.sh
PaddleCV/PaddleGAN/scripts/infer_stargan.sh
+1
-1
PaddleCV/PaddleGAN/scripts/infer_stgan.sh
PaddleCV/PaddleGAN/scripts/infer_stgan.sh
+1
-1
PaddleCV/PaddleGAN/scripts/run_stargan.sh
PaddleCV/PaddleGAN/scripts/run_stargan.sh
+1
-1
PaddleCV/PaddleGAN/trainer/AttGAN.py
PaddleCV/PaddleGAN/trainer/AttGAN.py
+15
-4
PaddleCV/PaddleGAN/trainer/CGAN.py
PaddleCV/PaddleGAN/trainer/CGAN.py
+11
-12
PaddleCV/PaddleGAN/trainer/CycleGAN.py
PaddleCV/PaddleGAN/trainer/CycleGAN.py
+14
-4
PaddleCV/PaddleGAN/trainer/DCGAN.py
PaddleCV/PaddleGAN/trainer/DCGAN.py
+15
-10
PaddleCV/PaddleGAN/trainer/Pix2pix.py
PaddleCV/PaddleGAN/trainer/Pix2pix.py
+14
-2
PaddleCV/PaddleGAN/trainer/STGAN.py
PaddleCV/PaddleGAN/trainer/STGAN.py
+15
-4
PaddleCV/PaddleGAN/trainer/StarGAN.py
PaddleCV/PaddleGAN/trainer/StarGAN.py
+15
-2
未找到文件。
PaddleCV/PaddleGAN/network/base_network.py
浏览文件 @
25932361
...
@@ -156,7 +156,7 @@ def conv2d(input,
...
@@ -156,7 +156,7 @@ def conv2d(input,
if
padding_type
==
"SAME"
:
if
padding_type
==
"SAME"
:
top_padding
,
bottom_padding
=
cal_padding
(
input
.
shape
[
2
],
stride
,
top_padding
,
bottom_padding
=
cal_padding
(
input
.
shape
[
2
],
stride
,
filter_size
)
filter_size
)
left_padding
,
right_padding
=
cal_padding
(
input
.
shape
[
2
],
stride
,
left_padding
,
right_padding
=
cal_padding
(
input
.
shape
[
3
],
stride
,
filter_size
)
filter_size
)
height_padding
=
bottom_padding
height_padding
=
bottom_padding
width_padding
=
right_padding
width_padding
=
right_padding
...
@@ -244,7 +244,7 @@ def deconv2d(input,
...
@@ -244,7 +244,7 @@ def deconv2d(input,
if
padding_type
==
"SAME"
:
if
padding_type
==
"SAME"
:
top_padding
,
bottom_padding
=
cal_padding
(
input
.
shape
[
2
],
stride
,
top_padding
,
bottom_padding
=
cal_padding
(
input
.
shape
[
2
],
stride
,
filter_size
)
filter_size
)
left_padding
,
right_padding
=
cal_padding
(
input
.
shape
[
2
],
stride
,
left_padding
,
right_padding
=
cal_padding
(
input
.
shape
[
3
],
stride
,
filter_size
)
filter_size
)
height_padding
=
bottom_padding
height_padding
=
bottom_padding
width_padding
=
right_padding
width_padding
=
right_padding
...
...
PaddleCV/PaddleGAN/scripts/infer_attgan.sh
浏览文件 @
25932361
python infer.py
--model_net
AttGAN
--init_model
output/checkpoints/119/
--dataset_dir
./data/celeba/
--image_size
128
python infer.py
--model_net
AttGAN
--init_model
output/checkpoints/119/
--dataset_dir
./data/celeba/
--image_size
128
--output
./infer_result/attgan/
PaddleCV/PaddleGAN/scripts/infer_cgan.sh
浏览文件 @
25932361
python infer.py
--model_net
CGAN
--init_model
./output/checkpoints/19/
--n_samples
32
--noise_size
100
python infer.py
--model_net
CGAN
--init_model
./output/checkpoints/19/
--n_samples
32
--noise_size
100
--output
./infer_result/c_gan/
PaddleCV/PaddleGAN/scripts/infer_cyclegan.sh
浏览文件 @
25932361
python infer.py
--init_model
output/checkpoints/199/
--dataset_dir
data/cityscapes/
--image_size
256
--n_samples
1
--crop_size
256
--input_style
B
--test_list
./data/cityscapes/testB.txt
--model_net
CycleGAN
--net_G
resnet_9block
--g_base_dims
32
python infer.py
--init_model
output/checkpoints/199/
--dataset_dir
data/cityscapes/
--image_size
256
--n_samples
1
--crop_size
256
--input_style
B
--test_list
./data/cityscapes/testB.txt
--model_net
CycleGAN
--net_G
resnet_9block
--g_base_dims
32
--output
./infer_result/cyclegan/
PaddleCV/PaddleGAN/scripts/infer_dcgan.sh
浏览文件 @
25932361
python infer.py
--model_net
DCGAN
--init_model
./output/checkpoints/19/
--n_samples
32
--noise_size
100
python infer.py
--model_net
DCGAN
--init_model
./output/checkpoints/19/
--n_samples
32
--noise_size
100
--output
./infer_result/dcgan/
PaddleCV/PaddleGAN/scripts/infer_pix2pix.sh
浏览文件 @
25932361
python infer.py
--init_model
output/checkpoints/199/
--image_size
256
--n_samples
1
--crop_size
256
--dataset_dir
data/cityscapes/
--model_net
Pix2pix
--net_G
unet_256
--test_list
data/cityscapes/testB.txt
python infer.py
--init_model
output/checkpoints/199/
--image_size
256
--n_samples
1
--crop_size
256
--dataset_dir
data/cityscapes/
--model_net
Pix2pix
--net_G
unet_256
--test_list
data/cityscapes/testB.txt
--output
./infer_result/pix2pix/
PaddleCV/PaddleGAN/scripts/infer_stargan.sh
浏览文件 @
25932361
python infer.py
--model_net
StarGAN
--init_model
./output/checkpoints/19/
--dataset_dir
./data/celeba/
--image_size
128
--c_dim
5
--selected_attrs
"Black_Hair,Blond_Hair,Brown_Hair,Male,Young"
python infer.py
--model_net
StarGAN
--init_model
./output/checkpoints/19/
--dataset_dir
./data/celeba/
--image_size
128
--c_dim
5
--selected_attrs
"Black_Hair,Blond_Hair,Brown_Hair,Male,Young"
--output
./infer_result/stargan/
PaddleCV/PaddleGAN/scripts/infer_stgan.sh
浏览文件 @
25932361
python infer.py
--model_net
STGAN
--init_model
./output/checkpoints/19/
--dataset_dir
./data/celeba/
--image_size
128
--use_gru
True
python infer.py
--model_net
STGAN
--init_model
./output/checkpoints/19/
--dataset_dir
./data/celeba/
--image_size
128
--use_gru
True
--output
./infer_result/stgan/
PaddleCV/PaddleGAN/scripts/run_stargan.sh
浏览文件 @
25932361
python train.py
--model_net
StarGAN
--dataset
celeba
--crop_size
178
--image_size
128
--train_list
./data/celeba/list_attr_celeba.txt
--batch_size
16
--epoch
20
>
log_out 2>log_err
python train.py
--model_net
StarGAN
--dataset
celeba
--crop_size
178
--image_size
128
--train_list
./data/celeba/list_attr_celeba.txt
--batch_size
16
--epoch
20
--gan_mode
wgan
>
log_out 2>log_err
PaddleCV/PaddleGAN/trainer/AttGAN.py
浏览文件 @
25932361
...
@@ -55,6 +55,9 @@ class GTrainer():
...
@@ -55,6 +55,9 @@ class GTrainer():
fluid
.
layers
.
square
(
fluid
.
layers
.
square
(
fluid
.
layers
.
elementwise_sub
(
fluid
.
layers
.
elementwise_sub
(
x
=
self
.
pred_fake
,
y
=
ones
)))
x
=
self
.
pred_fake
,
y
=
ones
)))
else
:
raise
NotImplementedError
(
"gan_mode {} is not support!"
.
format
(
cfg
.
gan_mode
))
self
.
g_loss_cls
=
fluid
.
layers
.
mean
(
self
.
g_loss_cls
=
fluid
.
layers
.
mean
(
fluid
.
layers
.
sigmoid_cross_entropy_with_logits
(
self
.
cls_fake
,
fluid
.
layers
.
sigmoid_cross_entropy_with_logits
(
self
.
cls_fake
,
...
@@ -126,6 +129,9 @@ class DTrainer():
...
@@ -126,6 +129,9 @@ class DTrainer():
cfg
=
cfg
,
cfg
=
cfg
,
name
=
"discriminator"
)
name
=
"discriminator"
)
self
.
d_loss
=
self
.
d_loss_real
+
self
.
d_loss_fake
+
1.0
*
self
.
d_loss_cls
+
cfg
.
lambda_gp
*
self
.
d_loss_gp
self
.
d_loss
=
self
.
d_loss_real
+
self
.
d_loss_fake
+
1.0
*
self
.
d_loss_cls
+
cfg
.
lambda_gp
*
self
.
d_loss_gp
else
:
raise
NotImplementedError
(
"gan_mode {} is not support!"
.
format
(
cfg
.
gan_mode
))
self
.
d_loss_real
.
persistable
=
True
self
.
d_loss_real
.
persistable
=
True
self
.
d_loss_fake
.
persistable
=
True
self
.
d_loss_fake
.
persistable
=
True
...
@@ -153,11 +159,11 @@ class DTrainer():
...
@@ -153,11 +159,11 @@ class DTrainer():
beta
=
fluid
.
layers
.
uniform_random_batch_size_like
(
beta
=
fluid
.
layers
.
uniform_random_batch_size_like
(
input
=
a
,
shape
=
a
.
shape
,
min
=
0.0
,
max
=
1.0
)
input
=
a
,
shape
=
a
.
shape
,
min
=
0.0
,
max
=
1.0
)
mean
=
fluid
.
layers
.
reduce_mean
(
mean
=
fluid
.
layers
.
reduce_mean
(
a
,
range
(
len
(
a
.
shape
)),
keep_dim
=
True
)
a
,
dim
=
list
(
range
(
len
(
a
.
shape
)
)),
keep_dim
=
True
)
input_sub_mean
=
fluid
.
layers
.
elementwise_sub
(
a
,
mean
,
axis
=
0
)
input_sub_mean
=
fluid
.
layers
.
elementwise_sub
(
a
,
mean
,
axis
=
0
)
var
=
fluid
.
layers
.
reduce_mean
(
var
=
fluid
.
layers
.
reduce_mean
(
fluid
.
layers
.
square
(
input_sub_mean
),
fluid
.
layers
.
square
(
input_sub_mean
),
range
(
len
(
a
.
shape
)),
dim
=
list
(
range
(
len
(
a
.
shape
)
)),
keep_dim
=
True
)
keep_dim
=
True
)
b
=
beta
*
fluid
.
layers
.
sqrt
(
var
)
*
0.5
+
a
b
=
beta
*
fluid
.
layers
.
sqrt
(
var
)
*
0.5
+
a
shape
=
[
a
.
shape
[
0
]]
shape
=
[
a
.
shape
[
0
]]
...
@@ -297,7 +303,10 @@ class AttGAN(object):
...
@@ -297,7 +303,10 @@ class AttGAN(object):
# prepare environment
# prepare environment
place
=
fluid
.
CUDAPlace
(
0
)
if
self
.
cfg
.
use_gpu
else
fluid
.
CPUPlace
()
place
=
fluid
.
CUDAPlace
(
0
)
if
self
.
cfg
.
use_gpu
else
fluid
.
CPUPlace
()
py_reader
.
decorate_batch_generator
(
self
.
train_reader
,
places
=
place
)
py_reader
.
decorate_batch_generator
(
self
.
train_reader
,
places
=
fluid
.
cuda_places
()
if
self
.
cfg
.
use_gpu
else
fluid
.
cpu_places
())
exe
=
fluid
.
Executor
(
place
)
exe
=
fluid
.
Executor
(
place
)
exe
.
run
(
fluid
.
default_startup_program
())
exe
.
run
(
fluid
.
default_startup_program
())
...
@@ -371,7 +380,9 @@ class AttGAN(object):
...
@@ -371,7 +380,9 @@ class AttGAN(object):
iterable
=
True
,
iterable
=
True
,
use_double_buffer
=
True
)
use_double_buffer
=
True
)
test_py_reader
.
decorate_batch_generator
(
test_py_reader
.
decorate_batch_generator
(
self
.
test_reader
,
places
=
place
)
self
.
test_reader
,
places
=
fluid
.
cuda_places
()
if
self
.
cfg
.
use_gpu
else
fluid
.
cpu_places
())
test_program
=
test_gen_trainer
.
infer_program
test_program
=
test_gen_trainer
.
infer_program
utility
.
save_test_image
(
epoch_id
,
self
.
cfg
,
exe
,
place
,
utility
.
save_test_image
(
epoch_id
,
self
.
cfg
,
exe
,
place
,
...
...
PaddleCV/PaddleGAN/trainer/CGAN.py
浏览文件 @
25932361
...
@@ -122,8 +122,11 @@ class CGAN(object):
...
@@ -122,8 +122,11 @@ class CGAN(object):
d_trainer
.
program
).
with_data_parallel
(
d_trainer
.
program
).
with_data_parallel
(
loss_name
=
d_trainer
.
d_loss
.
name
,
build_strategy
=
build_strategy
)
loss_name
=
d_trainer
.
d_loss
.
name
,
build_strategy
=
build_strategy
)
if
self
.
cfg
.
run_test
:
image_path
=
os
.
path
.
join
(
self
.
cfg
.
output
,
'images'
)
if
not
os
.
path
.
exists
(
image_path
):
os
.
makedirs
(
image_path
)
t_time
=
0
t_time
=
0
losses
=
[[],
[]]
for
epoch_id
in
range
(
self
.
cfg
.
epoch
):
for
epoch_id
in
range
(
self
.
cfg
.
epoch
):
for
batch_id
,
data
in
enumerate
(
self
.
train_reader
()):
for
batch_id
,
data
in
enumerate
(
self
.
train_reader
()):
if
len
(
data
)
!=
self
.
cfg
.
batch_size
:
if
len
(
data
)
!=
self
.
cfg
.
batch_size
:
...
@@ -159,13 +162,12 @@ class CGAN(object):
...
@@ -159,13 +162,12 @@ class CGAN(object):
fetch_list
=
[
d_trainer
.
d_loss
])[
0
]
fetch_list
=
[
d_trainer
.
d_loss
])[
0
]
d_fake_loss
=
exe
.
run
(
d_trainer_program
,
d_fake_loss
=
exe
.
run
(
d_trainer_program
,
feed
=
{
feed
=
{
'img'
:
generate_image
,
'img'
:
generate_image
[
0
]
,
'condition'
:
condition_data
,
'condition'
:
condition_data
,
'label'
:
fake_label
'label'
:
fake_label
},
},
fetch_list
=
[
d_trainer
.
d_loss
])[
0
]
fetch_list
=
[
d_trainer
.
d_loss
])[
0
]
d_loss
=
d_real_loss
+
d_fake_loss
d_loss
=
d_real_loss
+
d_fake_loss
losses
[
1
].
append
(
d_loss
)
for
_
in
six
.
moves
.
xrange
(
self
.
cfg
.
num_generator_time
):
for
_
in
six
.
moves
.
xrange
(
self
.
cfg
.
num_generator_time
):
g_loss
=
exe
.
run
(
g_trainer_program
,
g_loss
=
exe
.
run
(
g_trainer_program
,
...
@@ -174,15 +176,16 @@ class CGAN(object):
...
@@ -174,15 +176,16 @@ class CGAN(object):
'condition'
:
condition_data
'condition'
:
condition_data
},
},
fetch_list
=
[
g_trainer
.
g_loss
])[
0
]
fetch_list
=
[
g_trainer
.
g_loss
])[
0
]
losses
[
0
].
append
(
g_loss
)
batch_time
=
time
.
time
()
-
s_time
batch_time
=
time
.
time
()
-
s_time
if
batch_id
%
self
.
cfg
.
print_freq
==
0
:
print
(
'Epoch ID: {} Batch ID: {} D_loss: {} G_loss: {} Batch_time_cost: {}'
.
format
(
epoch_id
,
batch_id
,
d_loss
[
0
],
g_loss
[
0
],
batch_time
))
t_time
+=
batch_time
t_time
+=
batch_time
if
batch_id
%
self
.
cfg
.
print_freq
==
0
:
if
self
.
cfg
.
run_test
:
image_path
=
os
.
path
.
join
(
self
.
cfg
.
output
,
'images'
)
if
not
os
.
path
.
exists
(
image_path
):
os
.
makedirs
(
image_path
)
generate_const_image
=
exe
.
run
(
generate_const_image
=
exe
.
run
(
g_trainer
.
infer_program
,
g_trainer
.
infer_program
,
feed
=
{
'noise'
:
const_n
,
feed
=
{
'noise'
:
const_n
,
...
@@ -194,10 +197,6 @@ class CGAN(object):
...
@@ -194,10 +197,6 @@ class CGAN(object):
total_images
=
np
.
concatenate
(
total_images
=
np
.
concatenate
(
[
real_image
,
generate_image_reshape
])
[
real_image
,
generate_image_reshape
])
fig
=
utility
.
plot
(
total_images
)
fig
=
utility
.
plot
(
total_images
)
print
(
'Epoch ID: {} Batch ID: {} D_loss: {} G_loss: {} Batch_time_cost: {}'
.
format
(
epoch_id
,
batch_id
,
d_loss
[
0
],
g_loss
[
0
],
batch_time
))
plt
.
title
(
'Epoch ID={}, Batch ID={}'
.
format
(
epoch_id
,
plt
.
title
(
'Epoch ID={}, Batch ID={}'
.
format
(
epoch_id
,
batch_id
))
batch_id
))
img_name
=
'{:04d}_{:04d}.png'
.
format
(
epoch_id
,
batch_id
)
img_name
=
'{:04d}_{:04d}.png'
.
format
(
epoch_id
,
batch_id
)
...
...
PaddleCV/PaddleGAN/trainer/CycleGAN.py
浏览文件 @
25932361
...
@@ -259,8 +259,14 @@ class CycleGAN(object):
...
@@ -259,8 +259,14 @@ class CycleGAN(object):
# prepare environment
# prepare environment
place
=
fluid
.
CUDAPlace
(
0
)
if
self
.
cfg
.
use_gpu
else
fluid
.
CPUPlace
()
place
=
fluid
.
CUDAPlace
(
0
)
if
self
.
cfg
.
use_gpu
else
fluid
.
CPUPlace
()
A_py_reader
.
decorate_batch_generator
(
self
.
A_reader
,
places
=
place
)
A_py_reader
.
decorate_batch_generator
(
B_py_reader
.
decorate_batch_generator
(
self
.
B_reader
,
places
=
place
)
self
.
A_reader
,
places
=
fluid
.
cuda_places
()
if
self
.
cfg
.
use_gpu
else
fluid
.
cpu_places
())
B_py_reader
.
decorate_batch_generator
(
self
.
B_reader
,
places
=
fluid
.
cuda_places
()
if
self
.
cfg
.
use_gpu
else
fluid
.
cpu_places
())
exe
=
fluid
.
Executor
(
place
)
exe
=
fluid
.
Executor
(
place
)
exe
.
run
(
fluid
.
default_startup_program
())
exe
.
run
(
fluid
.
default_startup_program
())
...
@@ -359,9 +365,13 @@ class CycleGAN(object):
...
@@ -359,9 +365,13 @@ class CycleGAN(object):
use_double_buffer
=
True
)
use_double_buffer
=
True
)
A_test_py_reader
.
decorate_batch_generator
(
A_test_py_reader
.
decorate_batch_generator
(
self
.
A_test_reader
,
places
=
place
)
self
.
A_test_reader
,
places
=
fluid
.
cuda_places
()
if
self
.
cfg
.
use_gpu
else
fluid
.
cpu_places
())
B_test_py_reader
.
decorate_batch_generator
(
B_test_py_reader
.
decorate_batch_generator
(
self
.
B_test_reader
,
places
=
place
)
self
.
B_test_reader
,
places
=
fluid
.
cuda_places
()
if
self
.
cfg
.
use_gpu
else
fluid
.
cpu_places
())
test_program
=
gen_trainer
.
infer_program
test_program
=
gen_trainer
.
infer_program
utility
.
save_test_image
(
utility
.
save_test_image
(
epoch_id
,
epoch_id
,
...
...
PaddleCV/PaddleGAN/trainer/DCGAN.py
浏览文件 @
25932361
...
@@ -95,7 +95,7 @@ class DCGAN(object):
...
@@ -95,7 +95,7 @@ class DCGAN(object):
d_trainer
=
DTrainer
(
img
,
label
,
self
.
cfg
)
d_trainer
=
DTrainer
(
img
,
label
,
self
.
cfg
)
# prepare enviorment
# prepare enviorment
place
=
fluid
.
CUDAPlace
(
0
)
place
=
fluid
.
CUDAPlace
(
0
)
if
self
.
cfg
.
use_gpu
else
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
exe
=
fluid
.
Executor
(
place
)
exe
.
run
(
fluid
.
default_startup_program
())
exe
.
run
(
fluid
.
default_startup_program
())
...
@@ -118,6 +118,11 @@ class DCGAN(object):
...
@@ -118,6 +118,11 @@ class DCGAN(object):
d_trainer
.
program
).
with_data_parallel
(
d_trainer
.
program
).
with_data_parallel
(
loss_name
=
d_trainer
.
d_loss
.
name
,
build_strategy
=
build_strategy
)
loss_name
=
d_trainer
.
d_loss
.
name
,
build_strategy
=
build_strategy
)
if
self
.
cfg
.
run_test
:
image_path
=
os
.
path
.
join
(
self
.
cfg
.
output
,
'images'
)
if
not
os
.
path
.
exists
(
image_path
):
os
.
makedirs
(
image_path
)
t_time
=
0
t_time
=
0
for
epoch_id
in
range
(
self
.
cfg
.
epoch
):
for
epoch_id
in
range
(
self
.
cfg
.
epoch
):
for
batch_id
,
data
in
enumerate
(
self
.
train_reader
()):
for
batch_id
,
data
in
enumerate
(
self
.
train_reader
()):
...
@@ -148,7 +153,7 @@ class DCGAN(object):
...
@@ -148,7 +153,7 @@ class DCGAN(object):
fetch_list
=
[
d_trainer
.
d_loss
])[
0
]
fetch_list
=
[
d_trainer
.
d_loss
])[
0
]
d_fake_loss
=
exe
.
run
(
d_fake_loss
=
exe
.
run
(
d_trainer_program
,
d_trainer_program
,
feed
=
{
'img'
:
generate_image
,
feed
=
{
'img'
:
generate_image
[
0
]
,
'label'
:
fake_label
},
'label'
:
fake_label
},
fetch_list
=
[
d_trainer
.
d_loss
])[
0
]
fetch_list
=
[
d_trainer
.
d_loss
])[
0
]
d_loss
=
d_real_loss
+
d_fake_loss
d_loss
=
d_real_loss
+
d_fake_loss
...
@@ -164,12 +169,16 @@ class DCGAN(object):
...
@@ -164,12 +169,16 @@ class DCGAN(object):
fetch_list
=
[
g_trainer
.
g_loss
])[
0
]
fetch_list
=
[
g_trainer
.
g_loss
])[
0
]
batch_time
=
time
.
time
()
-
s_time
batch_time
=
time
.
time
()
-
s_time
t_time
+=
batch_time
if
batch_id
%
self
.
cfg
.
print_freq
==
0
:
if
batch_id
%
self
.
cfg
.
print_freq
==
0
:
image_path
=
os
.
path
.
join
(
self
.
cfg
.
output
,
'images'
)
print
(
if
not
os
.
path
.
exists
(
image_path
):
'Epoch ID: {} Batch ID: {} D_loss: {} G_loss: {} Batch_time_cost: {}'
.
os
.
makedirs
(
image_path
)
format
(
epoch_id
,
batch_id
,
d_loss
[
0
],
g_loss
[
0
],
batch_time
))
t_time
+=
batch_time
if
self
.
cfg
.
run_test
:
generate_const_image
=
exe
.
run
(
generate_const_image
=
exe
.
run
(
g_trainer
.
infer_program
,
g_trainer
.
infer_program
,
feed
=
{
'noise'
:
const_n
},
feed
=
{
'noise'
:
const_n
},
...
@@ -181,10 +190,6 @@ class DCGAN(object):
...
@@ -181,10 +190,6 @@ class DCGAN(object):
[
real_image
,
generate_image_reshape
])
[
real_image
,
generate_image_reshape
])
fig
=
utility
.
plot
(
total_images
)
fig
=
utility
.
plot
(
total_images
)
print
(
'Epoch ID: {} Batch ID: {} D_loss: {} G_loss: {} Batch_time_cost: {}'
.
format
(
epoch_id
,
batch_id
,
d_loss
[
0
],
g_loss
[
0
],
batch_time
))
plt
.
title
(
'Epoch ID={}, Batch ID={}'
.
format
(
epoch_id
,
plt
.
title
(
'Epoch ID={}, Batch ID={}'
.
format
(
epoch_id
,
batch_id
))
batch_id
))
img_name
=
'{:04d}_{:04d}.png'
.
format
(
epoch_id
,
batch_id
)
img_name
=
'{:04d}_{:04d}.png'
.
format
(
epoch_id
,
batch_id
)
...
...
PaddleCV/PaddleGAN/trainer/Pix2pix.py
浏览文件 @
25932361
...
@@ -56,6 +56,9 @@ class GTrainer():
...
@@ -56,6 +56,9 @@ class GTrainer():
self
.
g_loss_gan
=
fluid
.
layers
.
mean
(
self
.
g_loss_gan
=
fluid
.
layers
.
mean
(
fluid
.
layers
.
sigmoid_cross_entropy_with_logits
(
fluid
.
layers
.
sigmoid_cross_entropy_with_logits
(
x
=
self
.
pred
,
label
=
ones
))
x
=
self
.
pred
,
label
=
ones
))
else
:
raise
NotImplementedError
(
"gan_mode {} is not support!"
.
format
(
cfg
.
gan_mode
))
self
.
g_loss_L1
=
fluid
.
layers
.
reduce_mean
(
self
.
g_loss_L1
=
fluid
.
layers
.
reduce_mean
(
fluid
.
layers
.
abs
(
fluid
.
layers
.
abs
(
...
@@ -140,6 +143,10 @@ class DTrainer():
...
@@ -140,6 +143,10 @@ class DTrainer():
self
.
d_loss_fake
=
fluid
.
layers
.
mean
(
self
.
d_loss_fake
=
fluid
.
layers
.
mean
(
fluid
.
layers
.
sigmoid_cross_entropy_with_logits
(
fluid
.
layers
.
sigmoid_cross_entropy_with_logits
(
x
=
self
.
pred_fake
,
label
=
zeros
))
x
=
self
.
pred_fake
,
label
=
zeros
))
else
:
raise
NotImplementedError
(
"gan_mode {} is not support!"
.
format
(
cfg
.
gan_mode
))
self
.
d_loss
=
0.5
*
(
self
.
d_loss_real
+
self
.
d_loss_fake
)
self
.
d_loss
=
0.5
*
(
self
.
d_loss_real
+
self
.
d_loss_fake
)
vars
=
[]
vars
=
[]
for
var
in
self
.
program
.
list_vars
():
for
var
in
self
.
program
.
list_vars
():
...
@@ -225,7 +232,10 @@ class Pix2pix(object):
...
@@ -225,7 +232,10 @@ class Pix2pix(object):
# prepare environment
# prepare environment
place
=
fluid
.
CUDAPlace
(
0
)
if
self
.
cfg
.
use_gpu
else
fluid
.
CPUPlace
()
place
=
fluid
.
CUDAPlace
(
0
)
if
self
.
cfg
.
use_gpu
else
fluid
.
CPUPlace
()
py_reader
.
decorate_batch_generator
(
self
.
train_reader
,
places
=
place
)
py_reader
.
decorate_batch_generator
(
self
.
train_reader
,
places
=
fluid
.
cuda_places
()
if
self
.
cfg
.
use_gpu
else
fluid
.
cpu_places
())
exe
=
fluid
.
Executor
(
place
)
exe
=
fluid
.
Executor
(
place
)
exe
.
run
(
fluid
.
default_startup_program
())
exe
.
run
(
fluid
.
default_startup_program
())
...
@@ -299,7 +309,9 @@ class Pix2pix(object):
...
@@ -299,7 +309,9 @@ class Pix2pix(object):
iterable
=
True
,
iterable
=
True
,
use_double_buffer
=
True
)
use_double_buffer
=
True
)
test_py_reader
.
decorate_batch_generator
(
test_py_reader
.
decorate_batch_generator
(
self
.
test_reader
,
places
=
place
)
self
.
test_reader
,
places
=
fluid
.
cuda_places
()
if
self
.
cfg
.
use_gpu
else
fluid
.
cpu_places
())
test_program
=
gen_trainer
.
infer_program
test_program
=
gen_trainer
.
infer_program
utility
.
save_test_image
(
utility
.
save_test_image
(
epoch_id
,
epoch_id
,
...
...
PaddleCV/PaddleGAN/trainer/STGAN.py
浏览文件 @
25932361
...
@@ -54,6 +54,9 @@ class GTrainer():
...
@@ -54,6 +54,9 @@ class GTrainer():
fluid
.
layers
.
square
(
fluid
.
layers
.
square
(
fluid
.
layers
.
elementwise_sub
(
fluid
.
layers
.
elementwise_sub
(
x
=
self
.
pred_fake
,
y
=
ones
)))
x
=
self
.
pred_fake
,
y
=
ones
)))
else
:
raise
NotImplementedError
(
"gan_mode {} is not support!"
.
format
(
cfg
.
gan_mode
))
self
.
g_loss_cls
=
fluid
.
layers
.
mean
(
self
.
g_loss_cls
=
fluid
.
layers
.
mean
(
fluid
.
layers
.
sigmoid_cross_entropy_with_logits
(
self
.
cls_fake
,
fluid
.
layers
.
sigmoid_cross_entropy_with_logits
(
self
.
cls_fake
,
...
@@ -128,6 +131,9 @@ class DTrainer():
...
@@ -128,6 +131,9 @@ class DTrainer():
cfg
=
cfg
,
cfg
=
cfg
,
name
=
"discriminator"
)
name
=
"discriminator"
)
self
.
d_loss
=
self
.
d_loss_real
+
self
.
d_loss_fake
+
1.0
*
self
.
d_loss_cls
+
cfg
.
lambda_gp
*
self
.
d_loss_gp
self
.
d_loss
=
self
.
d_loss_real
+
self
.
d_loss_fake
+
1.0
*
self
.
d_loss_cls
+
cfg
.
lambda_gp
*
self
.
d_loss_gp
else
:
raise
NotImplementedError
(
"gan_mode {} is not support!"
.
format
(
cfg
.
gan_mode
))
self
.
d_loss_real
.
persistable
=
True
self
.
d_loss_real
.
persistable
=
True
self
.
d_loss_fake
.
persistable
=
True
self
.
d_loss_fake
.
persistable
=
True
...
@@ -159,11 +165,11 @@ class DTrainer():
...
@@ -159,11 +165,11 @@ class DTrainer():
beta
=
fluid
.
layers
.
uniform_random_batch_size_like
(
beta
=
fluid
.
layers
.
uniform_random_batch_size_like
(
input
=
a
,
shape
=
a
.
shape
,
min
=
0.0
,
max
=
1.0
)
input
=
a
,
shape
=
a
.
shape
,
min
=
0.0
,
max
=
1.0
)
mean
=
fluid
.
layers
.
reduce_mean
(
mean
=
fluid
.
layers
.
reduce_mean
(
a
,
range
(
len
(
a
.
shape
)),
keep_dim
=
True
)
a
,
dim
=
list
(
range
(
len
(
a
.
shape
)
)),
keep_dim
=
True
)
input_sub_mean
=
fluid
.
layers
.
elementwise_sub
(
a
,
mean
,
axis
=
0
)
input_sub_mean
=
fluid
.
layers
.
elementwise_sub
(
a
,
mean
,
axis
=
0
)
var
=
fluid
.
layers
.
reduce_mean
(
var
=
fluid
.
layers
.
reduce_mean
(
fluid
.
layers
.
square
(
input_sub_mean
),
fluid
.
layers
.
square
(
input_sub_mean
),
range
(
len
(
a
.
shape
)),
dim
=
list
(
range
(
len
(
a
.
shape
)
)),
keep_dim
=
True
)
keep_dim
=
True
)
b
=
beta
*
fluid
.
layers
.
sqrt
(
var
)
*
0.5
+
a
b
=
beta
*
fluid
.
layers
.
sqrt
(
var
)
*
0.5
+
a
shape
=
[
a
.
shape
[
0
]]
shape
=
[
a
.
shape
[
0
]]
...
@@ -308,7 +314,10 @@ class STGAN(object):
...
@@ -308,7 +314,10 @@ class STGAN(object):
# prepare environment
# prepare environment
place
=
fluid
.
CUDAPlace
(
0
)
if
self
.
cfg
.
use_gpu
else
fluid
.
CPUPlace
()
place
=
fluid
.
CUDAPlace
(
0
)
if
self
.
cfg
.
use_gpu
else
fluid
.
CPUPlace
()
py_reader
.
decorate_batch_generator
(
self
.
train_reader
,
places
=
place
)
py_reader
.
decorate_batch_generator
(
self
.
train_reader
,
places
=
fluid
.
cuda_places
()
if
self
.
cfg
.
use_gpu
else
fluid
.
cpu_places
())
exe
=
fluid
.
Executor
(
place
)
exe
=
fluid
.
Executor
(
place
)
exe
.
run
(
fluid
.
default_startup_program
())
exe
.
run
(
fluid
.
default_startup_program
())
...
@@ -380,7 +389,9 @@ class STGAN(object):
...
@@ -380,7 +389,9 @@ class STGAN(object):
iterable
=
True
,
iterable
=
True
,
use_double_buffer
=
True
)
use_double_buffer
=
True
)
test_py_reader
.
decorate_batch_generator
(
test_py_reader
.
decorate_batch_generator
(
self
.
test_reader
,
places
=
place
)
self
.
test_reader
,
places
=
fluid
.
cuda_places
()
if
self
.
cfg
.
use_gpu
else
fluid
.
cpu_places
())
test_program
=
test_gen_trainer
.
infer_program
test_program
=
test_gen_trainer
.
infer_program
utility
.
save_test_image
(
epoch_id
,
self
.
cfg
,
exe
,
place
,
utility
.
save_test_image
(
epoch_id
,
self
.
cfg
,
exe
,
place
,
test_program
,
test_gen_trainer
,
test_program
,
test_gen_trainer
,
...
...
PaddleCV/PaddleGAN/trainer/StarGAN.py
浏览文件 @
25932361
...
@@ -42,6 +42,10 @@ class GTrainer():
...
@@ -42,6 +42,10 @@ class GTrainer():
x
=
image_real
,
y
=
self
.
rec_img
)))
x
=
image_real
,
y
=
self
.
rec_img
)))
self
.
pred_fake
,
self
.
cls_fake
=
model
.
network_D
(
self
.
pred_fake
,
self
.
cls_fake
=
model
.
network_D
(
self
.
fake_img
,
cfg
,
name
=
"d_main"
)
self
.
fake_img
,
cfg
,
name
=
"d_main"
)
if
cfg
.
gan_mode
!=
'wgan'
:
raise
NotImplementedError
(
"gan_mode {} is not support! only support wgan"
.
format
(
cfg
.
gan_mode
))
#wgan
#wgan
self
.
g_loss_fake
=
-
1
*
fluid
.
layers
.
mean
(
self
.
pred_fake
)
self
.
g_loss_fake
=
-
1
*
fluid
.
layers
.
mean
(
self
.
pred_fake
)
...
@@ -104,6 +108,10 @@ class DTrainer():
...
@@ -104,6 +108,10 @@ class DTrainer():
self
.
d_loss_cls
=
fluid
.
layers
.
reduce_sum
(
self
.
d_loss_cls
=
fluid
.
layers
.
reduce_sum
(
fluid
.
layers
.
sigmoid_cross_entropy_with_logits
(
fluid
.
layers
.
sigmoid_cross_entropy_with_logits
(
self
.
cls_real
,
label_org
))
/
cfg
.
batch_size
self
.
cls_real
,
label_org
))
/
cfg
.
batch_size
if
cfg
.
gan_mode
!=
'wgan'
:
raise
NotImplementedError
(
"gan_mode {} is not support! only support wgan"
.
format
(
cfg
.
gan_mode
))
#wgan
#wgan
self
.
d_loss_fake
=
fluid
.
layers
.
mean
(
self
.
pred_fake
)
self
.
d_loss_fake
=
fluid
.
layers
.
mean
(
self
.
pred_fake
)
self
.
d_loss_real
=
-
1
*
fluid
.
layers
.
mean
(
self
.
pred_real
)
self
.
d_loss_real
=
-
1
*
fluid
.
layers
.
mean
(
self
.
pred_real
)
...
@@ -273,7 +281,10 @@ class StarGAN(object):
...
@@ -273,7 +281,10 @@ class StarGAN(object):
# prepare environment
# prepare environment
place
=
fluid
.
CUDAPlace
(
0
)
if
self
.
cfg
.
use_gpu
else
fluid
.
CPUPlace
()
place
=
fluid
.
CUDAPlace
(
0
)
if
self
.
cfg
.
use_gpu
else
fluid
.
CPUPlace
()
py_reader
.
decorate_batch_generator
(
self
.
train_reader
,
places
=
place
)
py_reader
.
decorate_batch_generator
(
self
.
train_reader
,
places
=
fluid
.
cuda_places
()
if
self
.
cfg
.
use_gpu
else
fluid
.
cpu_places
())
exe
=
fluid
.
Executor
(
place
)
exe
=
fluid
.
Executor
(
place
)
exe
.
run
(
fluid
.
default_startup_program
())
exe
.
run
(
fluid
.
default_startup_program
())
...
@@ -346,7 +357,9 @@ class StarGAN(object):
...
@@ -346,7 +357,9 @@ class StarGAN(object):
iterable
=
True
,
iterable
=
True
,
use_double_buffer
=
True
)
use_double_buffer
=
True
)
test_py_reader
.
decorate_batch_generator
(
test_py_reader
.
decorate_batch_generator
(
self
.
test_reader
,
places
=
place
)
self
.
test_reader
,
places
=
fluid
.
cuda_places
()
if
self
.
cfg
.
use_gpu
else
fluid
.
cpu_places
())
test_program
=
gen_trainer
.
infer_program
test_program
=
gen_trainer
.
infer_program
utility
.
save_test_image
(
epoch_id
,
self
.
cfg
,
exe
,
place
,
utility
.
save_test_image
(
epoch_id
,
self
.
cfg
,
exe
,
place
,
test_program
,
gen_trainer
,
test_program
,
gen_trainer
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录