Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
models
提交
4a93b001
M
models
项目概览
PaddlePaddle
/
models
大约 1 年 前同步成功
通知
222
Star
6828
Fork
2962
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
602
列表
看板
标记
里程碑
合并请求
255
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
models
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
602
Issue
602
列表
看板
标记
里程碑
合并请求
255
合并请求
255
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
4a93b001
编写于
7月 25, 2019
作者:
L
lvmengsi
提交者:
GitHub
7月 25, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Update gan (#2871)
* refine gan
上级
7b1a5565
变更
16
展开全部
隐藏空白更改
内联
并排
Showing
16 changed file
with
263 addition
and
306 deletion
+263
-306
PaddleCV/PaddleGAN/data_reader.py
PaddleCV/PaddleGAN/data_reader.py
+167
-200
PaddleCV/PaddleGAN/infer.py
PaddleCV/PaddleGAN/infer.py
+13
-17
PaddleCV/PaddleGAN/network/AttGAN_network.py
PaddleCV/PaddleGAN/network/AttGAN_network.py
+20
-20
PaddleCV/PaddleGAN/scripts/infer_attgan.sh
PaddleCV/PaddleGAN/scripts/infer_attgan.sh
+1
-1
PaddleCV/PaddleGAN/scripts/run_attgan.sh
PaddleCV/PaddleGAN/scripts/run_attgan.sh
+1
-1
PaddleCV/PaddleGAN/scripts/run_cyclegan.sh
PaddleCV/PaddleGAN/scripts/run_cyclegan.sh
+1
-1
PaddleCV/PaddleGAN/scripts/run_pix2pix.sh
PaddleCV/PaddleGAN/scripts/run_pix2pix.sh
+1
-1
PaddleCV/PaddleGAN/scripts/run_stargan.sh
PaddleCV/PaddleGAN/scripts/run_stargan.sh
+1
-1
PaddleCV/PaddleGAN/scripts/run_stgan.sh
PaddleCV/PaddleGAN/scripts/run_stgan.sh
+1
-1
PaddleCV/PaddleGAN/train.py
PaddleCV/PaddleGAN/train.py
+22
-35
PaddleCV/PaddleGAN/trainer/AttGAN.py
PaddleCV/PaddleGAN/trainer/AttGAN.py
+0
-2
PaddleCV/PaddleGAN/trainer/STGAN.py
PaddleCV/PaddleGAN/trainer/STGAN.py
+0
-2
PaddleCV/PaddleGAN/trainer/StarGAN.py
PaddleCV/PaddleGAN/trainer/StarGAN.py
+0
-2
PaddleCV/PaddleGAN/trainer/__init__.py
PaddleCV/PaddleGAN/trainer/__init__.py
+8
-0
PaddleCV/PaddleGAN/util/config.py
PaddleCV/PaddleGAN/util/config.py
+1
-1
PaddleCV/PaddleGAN/util/utility.py
PaddleCV/PaddleGAN/util/utility.py
+26
-21
未找到文件。
PaddleCV/PaddleGAN/data_reader.py
浏览文件 @
4a93b001
此差异已折叠。
点击以展开。
PaddleCV/PaddleGAN/infer.py
浏览文件 @
4a93b001
...
...
@@ -54,8 +54,8 @@ add_arg('image_size', int, 128, "image size")
add_arg
(
'selected_attrs'
,
str
,
"Bald,Bangs,Black_Hair,Blond_Hair,Brown_Hair,Bushy_Eyebrows,Eyeglasses,Male,Mouth_Slightly_Open,Mustache,No_Beard,Pale_Skin,Young"
,
"the attributes we selected to change"
)
add_arg
(
'
batch_size
'
,
int
,
16
,
"batch size when test"
)
add_arg
(
'test_list'
,
str
,
"./data/celeba/
test_
list_attr_celeba.txt"
,
"the test list file"
)
add_arg
(
'
n_samples
'
,
int
,
16
,
"batch size when test"
)
add_arg
(
'test_list'
,
str
,
"./data/celeba/list_attr_celeba.txt"
,
"the test list file"
)
add_arg
(
'dataset_dir'
,
str
,
"./data/celeba/"
,
"the dataset directory to be infered"
)
add_arg
(
'n_layers'
,
int
,
5
,
"default layers in generotor"
)
add_arg
(
'gru_n_layers'
,
int
,
4
,
"default layers of GRU in generotor"
)
...
...
@@ -149,11 +149,9 @@ def infer(args):
test_reader
=
celeba_reader_creator
(
image_dir
=
args
.
dataset_dir
,
list_filename
=
args
.
test_list
,
batch_size
=
args
.
batch_size
,
drop_last
=
False
,
args
=
args
)
reader_test
=
test_reader
.
get_test_reader
(
args
,
shuffle
=
False
,
return_name
=
True
)
args
=
args
,
mode
=
"VAL"
)
reader_test
=
test_reader
.
make_reader
(
return_name
=
True
)
for
data
in
zip
(
reader_test
()):
real_img
,
label_org
,
name
=
data
[
0
]
print
(
"read {}"
.
format
(
name
))
...
...
@@ -199,11 +197,9 @@ def infer(args):
test_reader
=
celeba_reader_creator
(
image_dir
=
args
.
dataset_dir
,
list_filename
=
args
.
test_list
,
batch_size
=
args
.
batch_size
,
drop_last
=
False
,
args
=
args
)
reader_test
=
test_reader
.
get_test_reader
(
args
,
shuffle
=
False
,
return_name
=
True
)
args
=
args
,
mode
=
"VAL"
)
reader_test
=
test_reader
.
make_reader
(
return_name
=
True
)
for
data
in
zip
(
reader_test
()):
real_img
,
label_org
,
name
=
data
[
0
]
print
(
"read {}"
.
format
(
name
))
...
...
@@ -256,9 +252,9 @@ def infer(args):
elif
args
.
model_net
==
'CGAN'
:
noise_data
=
np
.
random
.
uniform
(
low
=-
1.0
,
high
=
1.0
,
size
=
[
args
.
batch_size
,
args
.
noise_size
]).
astype
(
'float32'
)
size
=
[
args
.
n_samples
,
args
.
noise_size
]).
astype
(
'float32'
)
label
=
np
.
random
.
randint
(
0
,
9
,
size
=
[
args
.
batch_size
,
1
]).
astype
(
'float32'
)
0
,
9
,
size
=
[
args
.
n_samples
,
1
]).
astype
(
'float32'
)
noise_tensor
=
fluid
.
LoDTensor
()
conditions_tensor
=
fluid
.
LoDTensor
()
noise_tensor
.
set
(
noise_data
,
place
)
...
...
@@ -267,7 +263,7 @@ def infer(args):
fetch_list
=
[
fake
.
name
],
feed
=
{
"noise"
:
noise_tensor
,
"conditions"
:
conditions_tensor
})[
0
]
fake_image
=
np
.
reshape
(
fake_temp
,
(
args
.
batch_size
,
-
1
))
fake_image
=
np
.
reshape
(
fake_temp
,
(
args
.
n_samples
,
-
1
))
fig
=
utility
.
plot
(
fake_image
)
plt
.
savefig
(
...
...
@@ -277,12 +273,12 @@ def infer(args):
elif
args
.
model_net
==
'DCGAN'
:
noise_data
=
np
.
random
.
uniform
(
low
=-
1.0
,
high
=
1.0
,
size
=
[
args
.
batch_size
,
args
.
noise_size
]).
astype
(
'float32'
)
size
=
[
args
.
n_samples
,
args
.
noise_size
]).
astype
(
'float32'
)
noise_tensor
=
fluid
.
LoDTensor
()
noise_tensor
.
set
(
noise_data
,
place
)
fake_temp
=
exe
.
run
(
fetch_list
=
[
fake
.
name
],
feed
=
{
"noise"
:
noise_tensor
})[
0
]
fake_image
=
np
.
reshape
(
fake_temp
,
(
args
.
batch_size
,
-
1
))
fake_image
=
np
.
reshape
(
fake_temp
,
(
args
.
n_samples
,
-
1
))
fig
=
utility
.
plot
(
fake_image
)
plt
.
savefig
(
...
...
PaddleCV/PaddleGAN/network/AttGAN_network.py
浏览文件 @
4a93b001
...
...
@@ -71,10 +71,10 @@ class AttGAN_model(object):
d
=
min
(
dim
*
2
**
i
,
MAX_DIM
)
#SAME padding
z
=
conv2d
(
z
,
d
,
4
,
2
,
input
=
z
,
num_filters
=
d
,
filter_size
=
4
,
stride
=
2
,
padding_type
=
'SAME'
,
norm
=
'batch_norm'
,
activation_fn
=
'leaky_relu'
,
...
...
@@ -104,10 +104,10 @@ class AttGAN_model(object):
if
i
<
n_layers
-
1
:
d
=
min
(
dim
*
2
**
(
n_layers
-
1
-
i
),
MAX_DIM
)
z
=
deconv2d
(
z
,
d
,
4
,
2
,
input
=
z
,
num_filters
=
d
,
filter_size
=
4
,
stride
=
2
,
padding_type
=
'SAME'
,
name
=
name
+
str
(
i
),
norm
=
'batch_norm'
,
...
...
@@ -121,10 +121,10 @@ class AttGAN_model(object):
z
=
self
.
concat
(
z
,
a
)
else
:
x
=
z
=
deconv2d
(
z
,
3
,
4
,
2
,
input
=
z
,
num_filters
=
3
,
filter_size
=
4
,
stride
=
2
,
padding_type
=
'SAME'
,
name
=
name
+
str
(
i
),
activation_fn
=
'tanh'
,
...
...
@@ -146,10 +146,10 @@ class AttGAN_model(object):
for
i
in
range
(
n_layers
):
d
=
min
(
dim
*
2
**
i
,
MAX_DIM
)
y
=
conv2d
(
y
,
d
,
4
,
2
,
input
=
y
,
num_filters
=
d
,
filter_size
=
4
,
stride
=
2
,
norm
=
norm
,
padding
=
1
,
activation_fn
=
'leaky_relu'
,
...
...
@@ -159,8 +159,8 @@ class AttGAN_model(object):
initial
=
'kaiming'
)
logit_gan
=
linear
(
y
,
fc_dim
,
input
=
y
,
output_size
=
fc_dim
,
activation_fn
=
'relu'
,
name
=
name
+
'fc_adv_1'
,
initial
=
'kaiming'
)
...
...
@@ -168,8 +168,8 @@ class AttGAN_model(object):
logit_gan
,
1
,
name
=
name
+
'fc_adv_2'
,
initial
=
'kaiming'
)
logit_att
=
linear
(
y
,
fc_dim
,
input
=
y
,
output_size
=
fc_dim
,
activation_fn
=
'relu'
,
name
=
name
+
'fc_cls_1'
,
initial
=
'kaiming'
)
...
...
PaddleCV/PaddleGAN/scripts/infer_attgan.sh
浏览文件 @
4a93b001
python infer.py
--model_net
AttGAN
--init_model
output/checkpoints/1
9
9/
--dataset_dir
"data/celeba/"
--image_size
128
python infer.py
--model_net
AttGAN
--init_model
output/checkpoints/1
1
9/
--dataset_dir
"data/celeba/"
--image_size
128
PaddleCV/PaddleGAN/scripts/run_attgan.sh
浏览文件 @
4a93b001
python train.py
--model_net
AttGAN
--dataset
celeba
--crop_size
170
--image_size
128
--train_list
./data/celeba/list_attr_celeba.txt
--test_list
./data/celeba/test_list_attr_celeba.txt
--gan_mode
wgan
--batch_size
32
--print_freq
1
--num_discriminator_time
5
--epoch
9
0
>
log_out 2>log_err
python train.py
--model_net
AttGAN
--dataset
celeba
--crop_size
170
--image_size
128
--train_list
./data/celeba/list_attr_celeba.txt
--gan_mode
wgan
--batch_size
32
--print_freq
1
--num_discriminator_time
5
--epoch
12
0
>
log_out 2>log_err
PaddleCV/PaddleGAN/scripts/run_cyclegan.sh
浏览文件 @
4a93b001
python train.py
--model_net
CycleGAN
--dataset
cityscapes
--batch_size
1
--net_G
resnet_9block
--g_base_dim
32
--net_D
basic
--norm_type
batch_norm
--epoch
200
--
load
_size
286
--crop_size
256
--crop_type
Random
>
log_out 2>log_err
python train.py
--model_net
CycleGAN
--dataset
cityscapes
--batch_size
1
--net_G
resnet_9block
--g_base_dim
32
--net_D
basic
--norm_type
batch_norm
--epoch
200
--
image
_size
286
--crop_size
256
--crop_type
Random
>
log_out 2>log_err
PaddleCV/PaddleGAN/scripts/run_pix2pix.sh
浏览文件 @
4a93b001
python train.py
--model_net
Pix2pix
--dataset
cityscapes
--train_list
data/cityscapes/pix2pix_train_list
--test_list
data/cityscapes/pix2pix_test_list
--crop_type
Random
--dropout
True
--gan_mode
vanilla
--batch_size
1
>
log_out 2>log_err
python train.py
--model_net
Pix2pix
--dataset
cityscapes
--train_list
data/cityscapes/pix2pix_train_list
--test_list
data/cityscapes/pix2pix_test_list
--crop_type
Random
--dropout
True
--gan_mode
vanilla
--batch_size
1
--epoch
200
--image_size
286
--crop_size
256
>
log_out 2>log_err
PaddleCV/PaddleGAN/scripts/run_stargan.sh
浏览文件 @
4a93b001
python train.py
--model_net
StarGAN
--dataset
celeba
--crop_size
178
--image_size
128
--train_list
./data/celeba/list_attr_celeba.txt
--
test_list
./data/celeba/test_list_attr_celeba.txt
--
gan_mode
wgan
--batch_size
16
--epoch
20
>
log_out 2>log_err
python train.py
--model_net
StarGAN
--dataset
celeba
--crop_size
178
--image_size
128
--train_list
./data/celeba/list_attr_celeba.txt
--gan_mode
wgan
--batch_size
16
--epoch
20
>
log_out 2>log_err
PaddleCV/PaddleGAN/scripts/run_stgan.sh
浏览文件 @
4a93b001
python train.py
--model_net
STGAN
--dataset
celeba
--crop_size
170
--image_size
128
--train_list
./data/celeba/list_attr_celeba.txt
--
test_list
./data/celeba/test_list_attr_celeba.txt
--
gan_mode
wgan
--batch_size
32
--print_freq
1
--num_discriminator_time
5
--epoch
50
>
log_out 2>log_err
python train.py
--model_net
STGAN
--dataset
celeba
--crop_size
170
--image_size
128
--train_list
./data/celeba/list_attr_celeba.txt
--gan_mode
wgan
--batch_size
32
--print_freq
1
--num_discriminator_time
5
--epoch
50
>
log_out 2>log_err
PaddleCV/PaddleGAN/train.py
浏览文件 @
4a93b001
...
...
@@ -24,51 +24,38 @@ import time
import
numpy
as
np
import
paddle
import
paddle.fluid
as
fluid
import
trainer
def
train
(
cfg
):
MODELS
=
[
"CGAN"
,
"DCGAN"
,
"Pix2pix"
,
"CycleGAN"
,
"StarGAN"
,
"AttGAN"
,
"STGAN"
]
if
cfg
.
model_net
not
in
MODELS
:
raise
NotImplementedError
(
"{} is not support!"
.
format
(
cfg
.
model_net
))
reader
=
data_reader
(
cfg
)
if
cfg
.
model_net
==
'CycleGAN'
:
if
cfg
.
model_net
in
[
'CycleGAN'
]:
a_reader
,
b_reader
,
a_reader_test
,
b_reader_test
,
batch_num
=
reader
.
make_data
(
)
elif
cfg
.
model_net
==
'Pix2pix'
:
train_reader
,
test_reader
,
batch_num
=
reader
.
make_data
()
elif
cfg
.
model_net
==
'StarGAN'
:
train_reader
,
test_reader
,
batch_num
=
reader
.
make_data
()
else
:
if
cfg
.
dataset
==
'mnist'
:
if
cfg
.
dataset
in
[
'mnist'
]
:
train_reader
=
reader
.
make_data
()
else
:
train_reader
,
test_reader
,
batch_num
=
reader
.
make_data
()
if
cfg
.
model_net
==
'CGAN'
:
from
trainer.CGAN
import
CGAN
if
cfg
.
dataset
!=
'mnist'
:
raise
NotImplementedError
(
'CGAN only support mnist now!'
)
model
=
CGAN
(
cfg
,
train_reader
)
elif
cfg
.
model_net
==
'DCGAN'
:
from
trainer.DCGAN
import
DCGAN
if
cfg
.
model_net
in
[
'CGAN'
,
'DCGAN'
]:
if
cfg
.
dataset
!=
'mnist'
:
raise
NotImplementedError
(
'DCGAN only support mnist now!'
)
model
=
DCGAN
(
cfg
,
train_reader
)
elif
cfg
.
model_net
==
'CycleGAN'
:
from
trainer.CycleGAN
import
CycleGAN
model
=
CycleGAN
(
cfg
,
a_reader
,
b_reader
,
a_reader_test
,
b_reader_test
,
batch_num
)
elif
cfg
.
model_net
==
'Pix2pix'
:
from
trainer.Pix2pix
import
Pix2pix
model
=
Pix2pix
(
cfg
,
train_reader
,
test_reader
,
batch_num
)
elif
cfg
.
model_net
==
'StarGAN'
:
from
trainer.StarGAN
import
StarGAN
model
=
StarGAN
(
cfg
,
train_reader
,
test_reader
,
batch_num
)
elif
cfg
.
model_net
==
'AttGAN'
:
from
trainer.AttGAN
import
AttGAN
model
=
AttGAN
(
cfg
,
train_reader
,
test_reader
,
batch_num
)
elif
cfg
.
model_net
==
'STGAN'
:
from
trainer.STGAN
import
STGAN
model
=
STGAN
(
cfg
,
train_reader
,
test_reader
,
batch_num
)
raise
NotImplementedError
(
"CGAN/DCGAN only support MNIST now!"
)
model
=
trainer
.
__dict__
[
cfg
.
model_net
](
cfg
,
train_reader
)
elif
cfg
.
model_net
in
[
'CycleGAN'
]:
model
=
trainer
.
__dict__
[
cfg
.
model_net
](
cfg
,
a_reader
,
b_reader
,
a_reader_test
,
b_reader_test
,
batch_num
)
else
:
pass
model
=
trainer
.
__dict__
[
cfg
.
model_net
](
cfg
,
train_reader
,
test_reader
,
batch_num
)
model
.
build_model
()
...
...
@@ -77,13 +64,13 @@ if __name__ == "__main__":
cfg
=
config
.
parse_args
()
config
.
print_arguments
(
cfg
)
utility
.
check_gpu
(
cfg
.
use_gpu
)
#assert cfg.load_size >= cfg.crop_size, "Load Size CANNOT less than Crop Size!"
if
cfg
.
profile
:
if
cfg
.
use_gpu
:
with
profiler
.
profiler
(
'All'
,
'total'
,
'/tmp/profile'
)
as
prof
:
with
fluid
.
profiler
.
profiler
(
'All'
,
'total'
,
'/tmp/profile'
)
as
prof
:
train
(
cfg
)
else
:
with
profiler
.
profiler
(
"CPU"
,
sorted_key
=
'total'
)
as
cpuprof
:
with
fluid
.
profiler
.
profiler
(
"CPU"
,
sorted_key
=
'total'
)
as
cpuprof
:
train
(
cfg
)
else
:
train
(
cfg
)
PaddleCV/PaddleGAN/trainer/AttGAN.py
浏览文件 @
4a93b001
...
...
@@ -175,8 +175,6 @@ class DTrainer():
class
AttGAN
(
object
):
def
add_special_args
(
self
,
parser
):
parser
.
add_argument
(
'--image_size'
,
type
=
int
,
default
=
256
,
help
=
"image size"
)
parser
.
add_argument
(
'--g_lr'
,
type
=
float
,
...
...
PaddleCV/PaddleGAN/trainer/STGAN.py
浏览文件 @
4a93b001
...
...
@@ -173,8 +173,6 @@ class DTrainer():
class
STGAN
(
object
):
def
add_special_args
(
self
,
parser
):
parser
.
add_argument
(
'--image_size'
,
type
=
int
,
default
=
256
,
help
=
"image size"
)
parser
.
add_argument
(
'--g_lr'
,
type
=
float
,
...
...
PaddleCV/PaddleGAN/trainer/StarGAN.py
浏览文件 @
4a93b001
...
...
@@ -199,8 +199,6 @@ class DTrainer():
class
StarGAN
(
object
):
def
add_special_args
(
self
,
parser
):
parser
.
add_argument
(
'--image_size'
,
type
=
int
,
default
=
256
,
help
=
"image size"
)
parser
.
add_argument
(
'--g_lr'
,
type
=
float
,
default
=
0.0001
,
help
=
"learning rate of g"
)
parser
.
add_argument
(
...
...
PaddleCV/PaddleGAN/trainer/__init__.py
浏览文件 @
4a93b001
...
...
@@ -12,6 +12,14 @@
#See the License for the specific language governing permissions and
#limitations under the License.
from
.CGAN
import
CGAN
from
.DCGAN
import
DCGAN
from
.CycleGAN
import
CycleGAN
from
.Pix2pix
import
Pix2pix
from
.STGAN
import
STGAN
from
.StarGAN
import
StarGAN
from
.AttGAN
import
AttGAN
import
importlib
...
...
PaddleCV/PaddleGAN/util/config.py
浏览文件 @
4a93b001
...
...
@@ -77,7 +77,7 @@ def base_parse_args(parser):
add_arg
(
'epoch'
,
int
,
200
,
"The number of epoch to be trained."
)
add_arg
(
'g_base_dims'
,
int
,
64
,
"Base channels in generator"
)
add_arg
(
'd_base_dims'
,
int
,
64
,
"Base channels in discriminator"
)
add_arg
(
'
load
_size'
,
int
,
286
,
"the image size when load the image"
)
add_arg
(
'
image
_size'
,
int
,
286
,
"the image size when load the image"
)
add_arg
(
'crop_type'
,
str
,
'Centor'
,
"the crop type, choose = ['Centor', 'Random']"
)
add_arg
(
'crop_size'
,
int
,
256
,
"crop size when preprocess image"
)
...
...
PaddleCV/PaddleGAN/util/utility.py
浏览文件 @
4a93b001
...
...
@@ -113,9 +113,10 @@ def save_test_image(epoch,
images
=
[
real_img_temp
]
for
i
in
range
(
cfg
.
c_dim
):
label_trg_tmp
=
copy
.
deepcopy
(
label_org
)
label_trg_tmp
[
0
][
i
]
=
1.0
-
label_trg_tmp
[
0
][
i
]
label_trg
=
check_attribute_conflict
(
label_trg_tmp
,
attr_names
[
i
],
attr_names
)
for
j
in
range
(
len
(
label_org
)):
label_trg_tmp
[
j
][
i
]
=
1.0
-
label_trg_tmp
[
j
][
i
]
label_trg
=
check_attribute_conflict
(
label_trg_tmp
,
attr_names
[
i
],
attr_names
)
tensor_label_trg
=
fluid
.
LoDTensor
()
tensor_label_trg
.
set
(
label_trg
,
place
)
fake_temp
,
rec_temp
=
exe
.
run
(
...
...
@@ -126,11 +127,13 @@ def save_test_image(epoch,
"label_trg"
:
tensor_label_trg
},
fetch_list
=
[
g_trainer
.
fake_img
,
g_trainer
.
rec_img
])
fake_temp
=
save_batch_image
(
fake_temp
[
0
]
)
rec_temp
=
save_batch_image
(
rec_temp
[
0
]
)
fake_temp
=
save_batch_image
(
fake_temp
)
rec_temp
=
save_batch_image
(
rec_temp
)
images
.
append
(
fake_temp
)
images
.
append
(
rec_temp
)
images_concat
=
np
.
concatenate
(
images
,
1
)
if
len
(
label_org
)
>
1
:
images_concat
=
np
.
concatenate
(
images_concat
,
1
)
imageio
.
imwrite
(
out_path
+
"/fake_img"
+
str
(
epoch
)
+
"_"
+
name
[
0
],
((
images_concat
+
1
)
*
127.5
).
astype
(
np
.
uint8
))
elif
cfg
.
model_net
==
'AttGAN'
or
cfg
.
model_net
==
'STGAN'
:
...
...
@@ -184,12 +187,12 @@ def save_test_image(epoch,
else
:
for
data_A
,
data_B
in
zip
(
A_test_reader
(),
B_test_reader
()):
A_
name
=
data_A
[
0
][
1
]
B_
name
=
data_B
[
0
][
1
]
A_
data
,
A_name
=
data_A
B_
data
,
B_name
=
data_B
tensor_A
=
fluid
.
LoDTensor
()
tensor_B
=
fluid
.
LoDTensor
()
tensor_A
.
set
(
data_A
[
0
][
0
]
,
place
)
tensor_B
.
set
(
data_B
[
0
][
0
]
,
place
)
tensor_A
.
set
(
A_data
,
place
)
tensor_B
.
set
(
B_data
,
place
)
fake_A_temp
,
fake_B_temp
,
cyc_A_temp
,
cyc_B_temp
=
exe
.
run
(
test_program
,
fetch_list
=
[
...
...
@@ -205,18 +208,20 @@ def save_test_image(epoch,
input_A_temp
=
np
.
squeeze
(
data_A
[
0
][
0
]).
transpose
([
1
,
2
,
0
])
input_B_temp
=
np
.
squeeze
(
data_B
[
0
][
0
]).
transpose
([
1
,
2
,
0
])
imageio
.
imwrite
(
out_path
+
"/fakeB_"
+
str
(
epoch
)
+
"_"
+
A_name
,
(
(
fake_B_temp
+
1
)
*
127.5
).
astype
(
np
.
uint8
))
imageio
.
imwrite
(
out_path
+
"/fakeA_"
+
str
(
epoch
)
+
"_"
+
B_name
,
(
(
fake_A_temp
+
1
)
*
127.5
).
astype
(
np
.
uint8
))
imageio
.
imwrite
(
out_path
+
"/cycA_"
+
str
(
epoch
)
+
"_"
+
A_name
,
(
(
cyc_A_temp
+
1
)
*
127.5
).
astype
(
np
.
uint8
))
imageio
.
imwrite
(
out_path
+
"/cycB_"
+
str
(
epoch
)
+
"_"
+
B_name
,
(
(
cyc_B_temp
+
1
)
*
127.5
).
astype
(
np
.
uint8
))
imageio
.
imwrite
(
out_path
+
"/inputA_"
+
str
(
epoch
)
+
"_"
+
A_name
,
(
(
input_A_temp
+
1
)
*
127.5
).
astype
(
np
.
uint8
))
imageio
.
imwrite
(
out_path
+
"/inputB_"
+
str
(
epoch
)
+
"_"
+
B_name
,
(
(
input_B_temp
+
1
)
*
127.5
).
astype
(
np
.
uint8
))
imageio
.
imwrite
(
out_path
+
"/fakeB_"
+
str
(
epoch
)
+
"_"
+
A_name
[
0
],
((
fake_B_temp
+
1
)
*
127.5
).
astype
(
np
.
uint8
))
imageio
.
imwrite
(
out_path
+
"/fakeA_"
+
str
(
epoch
)
+
"_"
+
B_name
[
0
],
((
fake_A_temp
+
1
)
*
127.5
).
astype
(
np
.
uint8
))
imageio
.
imwrite
(
out_path
+
"/cycA_"
+
str
(
epoch
)
+
"_"
+
A_name
[
0
],
((
cyc_A_temp
+
1
)
*
127.5
).
astype
(
np
.
uint8
))
imageio
.
imwrite
(
out_path
+
"/cycB_"
+
str
(
epoch
)
+
"_"
+
B_name
[
0
],
((
cyc_B_temp
+
1
)
*
127.5
).
astype
(
np
.
uint8
))
imageio
.
imwrite
(
out_path
+
"/inputA_"
+
str
(
epoch
)
+
"_"
+
A_name
[
0
],
(
(
input_A_temp
+
1
)
*
127.5
).
astype
(
np
.
uint8
))
imageio
.
imwrite
(
out_path
+
"/inputB_"
+
str
(
epoch
)
+
"_"
+
B_name
[
0
],
(
(
input_B_temp
+
1
)
*
127.5
).
astype
(
np
.
uint8
))
class
ImagePool
(
object
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录