未验证 提交 23d450e9 编写于 作者: L lvmengsi 提交者: GitHub

Cherry pick gan0717 (#2831)

update cherry pick
上级 c399ca44
......@@ -18,9 +18,11 @@
本图像生成模型库包含CGAN\[[3](#参考文献)\], DCGAN\[[4](#参考文献)\], Pix2Pix\[[5](#参考文献)\], CycleGAN\[[6](#参考文献)\], StarGAN\[[7](#参考文献)\], AttGAN\[[8](#参考文献)\], STGAN\[[9](#参考文献)\]
注意:
1. AttGAN和STGAN的网络结构中,判别器去掉了instance norm。
2. StarGAN,AttGAN和STGAN由于梯度惩罚所需的操作目前只支持GPU,需使用GPU训练。
注意:
1. StarGAN,AttGAN和STGAN由于梯度惩罚所需的操作目前只支持GPU,需使用GPU训练。
2. CGAN和DCGAN仅支持多batch size训练。
3. CGAN和DCGAN两个模型训练使用的数据集为MNIST数据集;StarGAN,AttGAN和STGAN的数据集为CelebA数据集,测试集列表(test_list)和下载到的list文件格式相同,即包含测试集数量,属性列表,想要进行测试的图片和标签。Pix2Pix和CycleGAN支持的数据集可以参考download.py中的cycle_pix_dataset。
4. PaddlePaddle1.5.1及之前的版本不支持在AttGAN和STGAN模型里的判别器加上的instance norm。如果要在判别器中加上instance norm,请源码编译develop分支并安装。
图像生成模型库库的目录结构如下:
```
......@@ -58,7 +60,7 @@
### 安装说明
**安装[PaddlePaddle](https://github.com/PaddlePaddle/Paddle):**
在当前目录下运行样例代码需要PadddlePaddle Fluid的v.1.5或以上的版本。如果你的运行环境中的PaddlePaddle低于此版本,请根据[安装文档](http://paddlepaddle.org/documentation/docs/zh/1.4/beginners_guide/install/index_cn.html)中的说明来更新PaddlePaddle。
在当前目录下运行样例代码需要PadddlePaddle Fluid的v.1.5或以上的版本。如果你的运行环境中的PaddlePaddle低于此版本,请根据[安装文档](https://www.paddlepaddle.org.cn/documentation/docs/zh/1.5/beginners_guide/install/index_cn.html)中的说明来更新PaddlePaddle。
### 任务简介
......@@ -141,11 +143,19 @@ Pix2Pix和CycleGAN的效果图
StarGAN,AttGAN和STGAN的效果如图所示:
<p align="center">
<img src="images/female_stargan_attgan_stgan.png" width="650"/><br />
StarGAN,AttGAN和STGAN的效果图
<img src="images/stargan.jpg" width="500"/><br />
StarGAN的效果图(图片属性分别为:origial image, Black hair, Blond Hair, Brown Hair, Male, Young)
</p>
<p align="center">
<img src="images/attgan.jpg" width="1250"/><br />
AttGAN的效果图(图片属性分别为:original image, Bald, Bangs, Black Hair, Blond Hair, Brown Hair, Bushy Eyebrows, Eyeglasses, Male, Mouth Slightly Open, Mustache, No Beard, Pale Skin, Young)
</p>
<p align="center">
<img src="images/stgan.jpg" width="1250"/><br />
STGAN的效果图(图片属性分别为:original image, Bald, Bangs, Black Hair, Blond Hair, Brown Hair, Bushy Eyebrows, Eyeglasses, Male, Mouth Slightly Open, Mustache, No Beard, Pale Skin, Young)
</p>
- 每个GAN都给出了一份测试示例,放在scripts文件夹内,用户可以直接运行测试脚本得到测试结果。
......@@ -181,7 +191,7 @@ STGAN只输入有变化的标签,引入GRU结构,更好的选择变化的属
### 模型概览
- Pix2Pix由一个生成网络和一个判别网络组成。生成网络中编码部分的网络结构都是采用`convolution-batch norm-ReLU`作为基础结构,解码部分的网络结构由`transpose convolution-batch norm-ReLU`组成,判别网络基本是由`convolution-norm-leaky_ReLU`作为基础结构,详细的网络结构可以查看`network/Pix2pix_network.py`文件。生成网络提供两种可选的网络结构:Unet网络结构和普通的encoder-decoder网络结构。网络利用损失函数学习从输入图像到输出图像的映射,生成网络损失函数由CGAN的损失函数和L1损失函数组成,判别网络损失函数由CGAN的损失函数组成。生成器的网络结构如下图所示:
- Pix2Pix由一个生成网络和一个判别网络组成。生成网络中编码部分的网络结构都是采用`convolution-batch norm-ReLU`作为基础结构,解码部分的网络结构由`transpose convolution-batch norm-ReLU`组成,判别网络基本是由`convolution-norm-leaky_ReLU`作为基础结构,详细的网络结构可以查看`network/Pix2pix_network.py`文件。生成网络提供两种可选的网络结构:Unet网络结构和普通的encoder-decoder网络结构。网络利用损失函数学习从输入图像到输出图像的映射,生成网络损失函数由GAN的损失函数和L1损失函数组成,判别网络损失函数由GAN的损失函数组成。生成器的网络结构如下图所示:
<p align="center">
<img src="images/pix2pix_gen.png" width="550"/><br />
......@@ -189,7 +199,7 @@ Pix2Pix生成网络结构图[5]
</p>
- CycleGAN由两个生成网络和两个判别网络组成,生成网络A是输入A类风格的图片输出B类风格的图片,生成网络B是输入B类风格的图片输出A类风格的图片。生成网络中编码部分的网络结构都是采用`convolution-norm-ReLU`作为基础结构,解码部分的网络结构由`transpose convolution-norm-ReLU`组成,判别网络基本是由`convolution-norm-leaky_ReLU`作为基础结构,详细的网络结构可以查看`network/CycleGAN_network.py`文件。生成网络提供两种可选的网络结构:Unet网络结构和普通的encoder-decoder网络结构。生成网络损失函数由CGAN的损失函数,重构损失和自身损失组成,判别网络的损失函数由CGAN的损失函数组成。
- CycleGAN由两个生成网络和两个判别网络组成,生成网络A是输入A类风格的图片输出B类风格的图片,生成网络B是输入B类风格的图片输出A类风格的图片。生成网络中编码部分的网络结构都是采用`convolution-norm-ReLU`作为基础结构,解码部分的网络结构由`transpose convolution-norm-ReLU`组成,判别网络基本是由`convolution-norm-leaky_ReLU`作为基础结构,详细的网络结构可以查看`network/CycleGAN_network.py`文件。生成网络提供两种可选的网络结构:Unet网络结构和普通的encoder-decoder网络结构。生成网络损失函数由LSGAN的损失函数,重构损失和自身损失组成,判别网络的损失函数由LSGAN的损失函数组成。
<p align="center">
<img src="images/pix2pix_gen.png" width="550"/><br />
......@@ -197,7 +207,7 @@ CycleGAN生成网络结构图[5]
</p>
- StarGAN中生成网络的编码部分主要由`convolution-instance norm-ReLU`组成,解码部分主要由`transpose convolution-norm-ReLU`组成,判别网络主要由`convolution-leaky_ReLU`组成,详细网络结构可以查看`network/StarGAN_network.py`文件。生成网络的损失函数是由CGAN的损失函数,重构损失和分类损失组成,判别网络的损失函数由预测损失,分类损失和梯度惩罚损失组成。
- StarGAN中生成网络的编码部分主要由`convolution-instance norm-ReLU`组成,解码部分主要由`transpose convolution-norm-ReLU`组成,判别网络主要由`convolution-leaky_ReLU`组成,详细网络结构可以查看`network/StarGAN_network.py`文件。生成网络的损失函数是由WGAN的损失函数,重构损失和分类损失组成,判别网络的损失函数由预测损失,分类损失和梯度惩罚损失组成。
<p align="center">
<img src="images/stargan_gen.png" width=350 />
......@@ -207,7 +217,7 @@ StarGAN的生成网络结构[左]和判别网络结构[右] [7]
- AttGAN中生成网络的编码部分主要由`convolution-instance norm-ReLU`组成,解码部分由`transpose convolution-norm-ReLU`组成,判别网络主要由`convolution-leaky_ReLU`组成,详细网络结构可以查看`network/AttGAN_network.py`文件。生成网络的损失函数是由CGAN的损失函数,重构损失和分类损失组成,判别网络的损失函数由预测损失,分类损失和梯度惩罚损失组成。
- AttGAN中生成网络的编码部分主要由`convolution-instance norm-ReLU`组成,解码部分由`transpose convolution-norm-ReLU`组成,判别网络主要由`convolution-leaky_ReLU`组成,详细网络结构可以查看`network/AttGAN_network.py`文件。生成网络的损失函数是由WGAN的损失函数,重构损失和分类损失组成,判别网络的损失函数由预测损失,分类损失和梯度惩罚损失组成。
<p align="center">
<img src="images/attgan_net.png" width=800 /> <br />
......@@ -215,7 +225,7 @@ AttGAN的网络结构[8]
</p>
- STGAN中生成网络再编码器和解码器之间加入Selective Transfer Units\(STU\),有选择的转换编码网络,从而更好的适配解码网络。生成网络中的编码网络主要由`convolution-instance norm-ReLU`组成,解码网络主要由`transpose convolution-norm-leaky_ReLU`组成,判别网络主要由`convolution-leaky_ReLU`组成,详细网络结构可以查看`network/STGAN_network.py`文件。生成网络的损失函数是由CGAN的损失函数,重构损失和分类损失组成,判别网络的损失函数由预测损失,分类损失和梯度惩罚损失组成。
- STGAN中生成网络再编码器和解码器之间加入Selective Transfer Units\(STU\),有选择的转换编码网络,从而更好的适配解码网络。生成网络中的编码网络主要由`convolution-instance norm-ReLU`组成,解码网络主要由`transpose convolution-norm-leaky_ReLU`组成,判别网络主要由`convolution-leaky_ReLU`组成,详细网络结构可以查看`network/STGAN_network.py`文件。生成网络的损失函数是由WGAN的损失函数,重构损失和分类损失组成,判别网络的损失函数由预测损失,分类损失和梯度惩罚损失组成。
<p align="center">
<img src="images/stgan_net.png" width=800 /> <br />
......@@ -228,17 +238,20 @@ STGAN的网络结构[9]
## FAQ
**Q:** StarGAN/AttGAN/STGAN中属性没有变化,为什么?
**Q:** StarGAN/AttGAN/STGAN中属性没有变化,为什么?
**A:** 查看是否所有的标签都转换对了。
**Q:** 预测结果不正常,是怎么回事?
**A:** 某些GAN预测的时候batch_norm的设置需要和训练的时候行为一致,查看模型库中相应的GAN中预测时batch_norm的行为和自己模型中的预测时batch_norm的
行为是否一致。
行为是否一致。
**Q:** 为什么STGAN和ATTGAN中变男性得到的预测结果是变女性呢?
**A:** 这是由于预测时标签的设置,目标标签是基于原本的标签进行改变,比如原本图片是男生,预测代码对标签进行转变的时候会自动变成相对立的标签,即女
性,所以得到的结果是女生。如果想要原本是男生,转变之后还是男生,保持要转变的标签不变即可。
**Q:** 如何使用自己的数据集进行训练?
**A:** 对于Pix2Pix来说,只要准备好类似于Cityscapes数据集的不同风格的成对的数据即可。对于CycleGAN来说,只要准备类似于Cityscapes数据集的不同风格的数据即可。对于StarGAN,AttGAN和STGAN来说,除了需要准备类似于CelebA数据集中的图片和标签文件外,还需要把模型中的selected_attrs参数设置为想要改变的目标属性,c_dim参数这是为目标属性的个数。
## 参考论文
[1] [Goodfellow, Ian J.; Pouget-Abadie, Jean; Mirza, Mehdi; Xu, Bing; Warde-Farley, David; Ozair, Sherjil; Courville, Aaron; Bengio, Yoshua. Generative Adversarial Networks. 2014. arXiv:1406.2661 [stat.ML].](https://arxiv.org/abs/1406.2661)
......
......@@ -27,7 +27,7 @@ import imageio
import glob
from util.config import add_arguments, print_arguments
from data_reader import celeba_reader_creator
from util.utility import check_attribute_conflict, check_gpu
from util.utility import check_attribute_conflict, check_gpu, save_batch_image
import copy
parser = argparse.ArgumentParser(description=__doc__)
......@@ -118,7 +118,6 @@ def infer(args):
print(args.init_model + '/' + model_name)
fluid.io.load_persistables(exe, args.init_model + "/" + model_name)
print('load params done')
if not os.path.exists(args.output):
os.makedirs(args.output)
......@@ -144,7 +143,7 @@ def infer(args):
tensor_label_trg_ = fluid.LoDTensor()
tensor_img.set(real_img, place)
tensor_label_org.set(label_org, place)
real_img_temp = np.squeeze(real_img).transpose([0, 2, 3, 1])
real_img_temp = save_batch_image(real_img)
images = [real_img_temp]
for i in range(args.c_dim):
label_trg_tmp = copy.deepcopy(label_trg)
......@@ -152,11 +151,13 @@ def infer(args):
label_trg_tmp[j][i] = 1.0 - label_trg_tmp[j][i]
label_trg_tmp = check_attribute_conflict(
label_trg_tmp, attr_names[i], attr_names)
label_org_ = list(map(lambda x: ((x * 2) - 1) * 0.5, label_org))
label_trg_ = list(
map(lambda x: ((x * 2) - 1) * 0.5, label_trg_tmp))
for j in range(len(label_org)):
label_trg_[j][i] = label_trg_[j][i] * 2.0
tensor_label_org_.set(label_org, place)
if args.model_net == 'AttGAN':
for k in range(len(label_org)):
label_trg_[k][i] = label_trg_[k][i] * 2.0
tensor_label_org_.set(label_org_, place)
tensor_label_trg.set(label_trg, place)
tensor_label_trg_.set(label_trg_, place)
out = exe.run(feed={
......@@ -165,10 +166,11 @@ def infer(args):
"label_trg_": tensor_label_trg_
},
fetch_list=[fake.name])
fake_temp = np.squeeze(out[0]).transpose([0, 2, 3, 1])
fake_temp = save_batch_image(out[0])
images.append(fake_temp)
images_concat = np.concatenate(images, 1)
images_concat = np.concatenate(images_concat, 1)
if len(label_org) > 1:
images_concat = np.concatenate(images_concat, 1)
imageio.imwrite(args.output + "/fake_img_" + name[0], (
(images_concat + 1) * 127.5).astype(np.uint8))
elif args.model_net == 'StarGAN':
......@@ -187,7 +189,7 @@ def infer(args):
tensor_label_org = fluid.LoDTensor()
tensor_img.set(real_img, place)
tensor_label_org.set(label_org, place)
real_img_temp = np.squeeze(real_img).transpose([0, 2, 3, 1])
real_img_temp = save_batch_image(real_img)
images = [real_img_temp]
for i in range(args.c_dim):
label_trg_tmp = copy.deepcopy(label_org)
......@@ -201,10 +203,11 @@ def infer(args):
feed={"input": tensor_img,
"label_trg_": tensor_label_trg},
fetch_list=[fake.name])
fake_temp = np.squeeze(out[0]).transpose([0, 2, 3, 1])
fake_temp = save_batch_image(out[0])
images.append(fake_temp)
images_concat = np.concatenate(images, 1)
images_concat = np.concatenate(images_concat, 1)
if len(label_org) > 1:
images_concat = np.concatenate(images_concat, 1)
imageio.imwrite(args.output + "/fake_img_" + name[0], (
(images_concat + 1) * 127.5).astype(np.uint8))
......
......@@ -55,6 +55,7 @@ class AttGAN_model(object):
name=name,
dim=cfg.d_base_dims,
fc_dim=cfg.d_fc_dim,
norm=cfg.dis_norm,
n_layers=cfg.n_layers)
def concat(self, z, a):
......@@ -149,11 +150,11 @@ class AttGAN_model(object):
d,
4,
2,
norm=None,
norm=norm,
padding=1,
activation_fn='leaky_relu',
name=name + str(i),
use_bias=True,
use_bias=(norm == None),
relufactor=0.01,
initial='kaiming')
......
......@@ -35,6 +35,10 @@ class CGAN_model(object):
self.gf_dim = 128
self.df_dim = 64
self.leaky_relu_factor = 0.2
if self.batch_size == 1:
self.norm = None
else:
self.norm = "batch_norm"
def network_G(self, input, label, name="generator"):
# concat noise and label
......@@ -43,14 +47,14 @@ class CGAN_model(object):
o_l1 = linear(
xy,
self.gf_dim * 8,
norm='batch_norm',
norm=self.norm,
activation_fn='relu',
name=name + '_l1')
o_c1 = fluid.layers.concat([o_l1, y], 1)
o_l2 = linear(
o_c1,
self.gf_dim * (self.img_w // 4) * (self.img_h // 4),
norm='batch_norm',
norm=self.norm,
activation_fn='relu',
name=name + '_l2')
o_r1 = fluid.layers.reshape(
......@@ -107,7 +111,7 @@ class CGAN_model(object):
o_l3 = linear(
o_c2,
self.df_dim * 16,
norm='batch_norm',
norm=self.norm,
activation_fn='leaky_relu',
name=name + '_l3')
o_c3 = fluid.layers.concat([o_l3, y], 1)
......
......@@ -31,13 +31,17 @@ class DCGAN_model(object):
self.dfc_dim = 1024
self.gf_dim = 64
self.df_dim = 64
if self.batch_size == 1:
self.norm = None
else:
self.norm = "batch_norm"
def network_G(self, input, name="generator"):
o_l1 = linear(input, self.gfc_dim, norm='batch_norm', name=name + '_l1')
o_l1 = linear(input, self.gfc_dim, norm=self.norm, name=name + '_l1')
o_l2 = linear(
o_l1,
self.gf_dim * 2 * self.img_dim // 4 * self.img_dim // 4,
norm='batch_norm',
norm=self.norm,
name=name + '_l2')
o_r1 = fluid.layers.reshape(
o_l2, [-1, self.df_dim * 2, self.img_dim // 4, self.img_dim // 4])
......@@ -85,7 +89,7 @@ class DCGAN_model(object):
o_l1 = linear(
o_c2,
self.dfc_dim,
norm='batch_norm',
norm=self.norm,
activation_fn='leaky_relu',
name=name + '_l1')
out = linear(o_l1, 1, activation_fn='sigmoid', name=name + '_l2')
......
......@@ -76,6 +76,7 @@ class STGAN_model(object):
n_atts=cfg.c_dim,
dim=cfg.d_base_dims,
fc_dim=cfg.d_fc_dim,
norm=cfg.dis_norm,
n_layers=cfg.n_layers,
name=name)
......@@ -100,7 +101,7 @@ class STGAN_model(object):
activation_fn='leaky_relu',
name=name + str(i),
use_bias=False,
relufactor=0.01,
relufactor=0.2,
initial='kaiming',
is_test=is_test)
zs.append(z)
......@@ -132,7 +133,7 @@ class STGAN_model(object):
pass_state=pass_state,
name=name + str(i),
is_test=is_test)
zs_.insert(0, output[0] + zs[n_layers - 1 - i])
zs_.insert(0, output[0])
if inject_layers > i:
state = self.concat(output[1], a)
else:
......@@ -202,18 +203,18 @@ class STGAN_model(object):
d,
4,
2,
norm=None,
padding=1,
norm=norm,
padding_type="SAME",
activation_fn='leaky_relu',
name=name + str(i),
use_bias=True,
relufactor=0.01,
use_bias=(norm == None),
relufactor=0.2,
initial='kaiming')
logit_gan = linear(
y,
fc_dim,
activation_fn='relu',
activation_fn='leaky_relu',
name=name + 'fc_adv_1',
initial='kaiming')
logit_gan = linear(
......@@ -222,7 +223,7 @@ class STGAN_model(object):
logit_att = linear(
y,
fc_dim,
activation_fn='relu',
activation_fn='leaky_relu',
name=name + 'fc_cls_1',
initial='kaiming')
logit_att = linear(
......
......@@ -76,7 +76,7 @@ def norm_layer(input, norm_type='batch_norm', name=None, is_test=False):
tmp = fluid.layers.elementwise_add(tmp, offset, axis=1)
return tmp
else:
raise NotImplementedError("norm tyoe: [%s] is not support" % norm_type)
raise NotImplementedError("norm type: [%s] is not support" % norm_type)
def initial_type(name,
......
python train.py --model_net AttGAN --dataset celeba --crop_size 170 --image_size 128 --train_list ./data/celeba/list_attr_celeba.txt --test_list ./data/celeba/test_list_attr_celeba.txt --gan_mode wgan --batch_size 32 --print_freq 1 --num_discriminator_time 5 --epoch 90 >log.out 2>log_err
python train.py --model_net AttGAN --dataset celeba --crop_size 170 --image_size 128 --train_list ./data/celeba/list_attr_celeba.txt --test_list ./data/celeba/test_list_attr_celeba.txt --gan_mode wgan --batch_size 32 --print_freq 1 --num_discriminator_time 5 --epoch 90 >log_out 2>log_err
python train.py --model_net STGAN --dataset celeba --crop_size 170 --image_size 128 --train_list ./data/celeba/list_attr_celeba.txt --test_list ./data/celeba/test_list_attr_celeba.txt --gan_mode wgan --batch_size 32 --print_freq 1 --num_discriminator_time 5 --epoch 20 >log.out 2>log_err
python train.py --model_net STGAN --dataset celeba --crop_size 170 --image_size 128 --train_list ./data/celeba/list_attr_celeba.txt --test_list ./data/celeba/test_list_attr_celeba.txt --gan_mode wgan --batch_size 32 --print_freq 1 --num_discriminator_time 5 --epoch 50 >log_out 2>log_err
......@@ -229,6 +229,12 @@ class AttGAN(object):
type=int,
default=5,
help="default layers in the network")
parser.add_argument(
'--dis_norm',
type=str,
default=None,
help="the normalization in discriminator, choose in [None, instance_norm]"
)
return parser
......
......@@ -79,7 +79,7 @@ class DTrainer():
clone_image_real = b.var('image_real')
break
self.fake_img, _ = model.network_G(
image_real, label_org, label_trg_, cfg, name="generator")
image_real, label_org_, label_trg_, cfg, name="generator")
self.pred_real, self.cls_real = model.network_D(
image_real, cfg, name="discriminator")
self.pred_real.persistable = True
......@@ -234,6 +234,12 @@ class STGAN(object):
type=int,
default=4,
help="default layers of GRU in generotor")
parser.add_argument(
'--dis_norm',
type=str,
default=None,
help="the normalization in discriminator, choose in [None, instance_norm]"
)
return parser
......
......@@ -109,13 +109,13 @@ def save_test_image(epoch,
tensor_label_org = fluid.LoDTensor()
tensor_img.set(real_img, place)
tensor_label_org.set(label_org, place)
real_img_temp = np.squeeze(real_img).transpose([1, 2, 0])
real_img_temp = save_batch_image(real_img)
images = [real_img_temp]
for i in range(cfg.c_dim):
label_trg_tmp = copy.deepcopy(label_org)
label_trg_tmp[0][i] = 1.0 - label_trg_tmp[0][i]
label_trg = check_attribute_conflict(
label_trg_tmp, attr_names[i], attr_names)
label_trg = check_attribute_conflict(label_trg_tmp,
attr_names[i], attr_names)
tensor_label_trg = fluid.LoDTensor()
tensor_label_trg.set(label_trg, place)
fake_temp, rec_temp = exe.run(
......@@ -126,8 +126,8 @@ def save_test_image(epoch,
"label_trg": tensor_label_trg
},
fetch_list=[g_trainer.fake_img, g_trainer.rec_img])
fake_temp = np.squeeze(fake_temp[0]).transpose([1, 2, 0])
rec_temp = np.squeeze(rec_temp[0]).transpose([1, 2, 0])
fake_temp = save_batch_image(fake_temp[0])
rec_temp = save_batch_image(rec_temp[0])
images.append(fake_temp)
images.append(rec_temp)
images_concat = np.concatenate(images, 1)
......@@ -145,7 +145,7 @@ def save_test_image(epoch,
tensor_label_trg_ = fluid.LoDTensor()
tensor_img.set(real_img, place)
tensor_label_org.set(label_org, place)
real_img_temp = np.squeeze(real_img).transpose([0, 2, 3, 1])
real_img_temp = save_batch_image(real_img)
images = [real_img_temp]
for i in range(cfg.c_dim):
label_trg_tmp = copy.deepcopy(label_trg)
......@@ -155,12 +155,14 @@ def save_test_image(epoch,
label_trg_tmp = check_attribute_conflict(
label_trg_tmp, attr_names[i], attr_names)
label_org_ = list(map(lambda x: ((x * 2) - 1) * 0.5, label_org))
label_trg_ = list(
map(lambda x: ((x * 2) - 1) * 0.5, label_trg_tmp))
for j in range(len(label_org)):
label_trg_[j][i] = label_trg_[j][i] * 2.0
tensor_label_org_.set(label_org, place)
if cfg.model_net == 'AttGAN':
for k in range(len(label_org)):
label_trg_[k][i] = label_trg_[k][i] * 2.0
tensor_label_org_.set(label_org_, place)
tensor_label_trg.set(label_trg, place)
tensor_label_trg_.set(label_trg_, place)
out = exe.run(test_program,
......@@ -172,10 +174,11 @@ def save_test_image(epoch,
"label_trg_": tensor_label_trg_
},
fetch_list=[g_trainer.fake_img])
fake_temp = np.squeeze(out[0]).transpose([0, 2, 3, 1])
fake_temp = save_batch_image(out[0])
images.append(fake_temp)
images_concat = np.concatenate(images, 1)
images_concat = np.concatenate(images_concat, 1)
if len(label_org) > 1:
images_concat = np.concatenate(images_concat, 1)
imageio.imwrite(out_path + "/fake_img" + str(epoch) + '_' + name[0],
((images_concat + 1) * 127.5).astype(np.uint8))
......@@ -266,20 +269,28 @@ def check_attribute_conflict(label_batch, attr, attrs):
return label_batch
def save_batch_image(img):
if len(img) == 1:
res_img = np.squeeze(img).transpose([1, 2, 0])
else:
res_img = np.squeeze(img).transpose([0, 2, 3, 1])
return res_img
def check_gpu(use_gpu):
"""
"""
Log error and exit when set use_gpu=true in paddlepaddle
cpu version.
"""
err = "Config use_gpu cannot be set as true while you are " \
"using paddlepaddle cpu version ! \nPlease try: \n" \
"\t1. Install paddlepaddle-gpu to run model on GPU \n" \
"\t2. Set use_gpu as false in config file to run " \
"model on CPU"
err = "Config use_gpu cannot be set as true while you are " \
"using paddlepaddle cpu version ! \nPlease try: \n" \
"\t1. Install paddlepaddle-gpu to run model on GPU \n" \
"\t2. Set use_gpu as false in config file to run " \
"model on CPU"
try:
if use_gpu and not fluid.is_compiled_with_cuda():
logger.error(err)
sys.exit(1)
except Exception as e:
pass
try:
if use_gpu and not fluid.is_compiled_with_cuda():
logger.error(err)
sys.exit(1)
except Exception as e:
pass
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册