未验证 提交 fd84c497 编写于 作者: L lvmengsi 提交者: GitHub

add instance norm (#2804)

* add instance norm
上级 33d384b4
......@@ -18,9 +18,11 @@
本图像生成模型库包含CGAN\[[3](#参考文献)\], DCGAN\[[4](#参考文献)\], Pix2Pix\[[5](#参考文献)\], CycleGAN\[[6](#参考文献)\], StarGAN\[[7](#参考文献)\], AttGAN\[[8](#参考文献)\], STGAN\[[9](#参考文献)\]
注意:
1. AttGAN和STGAN的网络结构中,判别器去掉了instance norm。
2. StarGAN,AttGAN和STGAN由于梯度惩罚所需的操作目前只支持GPU,需使用GPU训练。
注意:
1. StarGAN,AttGAN和STGAN由于梯度惩罚所需的操作目前只支持GPU,需使用GPU训练。
2. CGAN和DCGAN仅支持多batch size训练。
3. CGAN和DCGAN两个模型训练使用的数据集为MNIST数据集;StarGAN,AttGAN和STGAN的数据集为CelebA数据集,测试集列表(test_list)和下载到的list文件格式相同,即包含测试集数量,属性列表,想要进行测试的图片和标签。Pix2Pix和CycleGAN支持的数据集可以参考download.py中的cycle_pix_dataset。
4. PaddlePaddle1.5.1及之前的版本不支持在AttGAN和STGAN模型里的判别器加上的instance norm。如果要在判别器中加上instance norm,请源码编译develop分支并安装。
图像生成模型库库的目录结构如下:
```
......@@ -58,7 +60,7 @@
### 安装说明
**安装[PaddlePaddle](https://github.com/PaddlePaddle/Paddle):**
在当前目录下运行样例代码需要PadddlePaddle Fluid的v.1.5或以上的版本。如果你的运行环境中的PaddlePaddle低于此版本,请根据[安装文档](http://paddlepaddle.org/documentation/docs/zh/1.4/beginners_guide/install/index_cn.html)中的说明来更新PaddlePaddle。
在当前目录下运行样例代码需要PadddlePaddle Fluid的v.1.5或以上的版本。如果你的运行环境中的PaddlePaddle低于此版本,请根据[安装文档](https://www.paddlepaddle.org.cn/documentation/docs/zh/1.5/beginners_guide/install/index_cn.html)中的说明来更新PaddlePaddle。
### 任务简介
......@@ -181,7 +183,7 @@ STGAN只输入有变化的标签,引入GRU结构,更好的选择变化的属
### 模型概览
- Pix2Pix由一个生成网络和一个判别网络组成。生成网络中编码部分的网络结构都是采用`convolution-batch norm-ReLU`作为基础结构,解码部分的网络结构由`transpose convolution-batch norm-ReLU`组成,判别网络基本是由`convolution-norm-leaky_ReLU`作为基础结构,详细的网络结构可以查看`network/Pix2pix_network.py`文件。生成网络提供两种可选的网络结构:Unet网络结构和普通的encoder-decoder网络结构。网络利用损失函数学习从输入图像到输出图像的映射,生成网络损失函数由CGAN的损失函数和L1损失函数组成,判别网络损失函数由CGAN的损失函数组成。生成器的网络结构如下图所示:
- Pix2Pix由一个生成网络和一个判别网络组成。生成网络中编码部分的网络结构都是采用`convolution-batch norm-ReLU`作为基础结构,解码部分的网络结构由`transpose convolution-batch norm-ReLU`组成,判别网络基本是由`convolution-norm-leaky_ReLU`作为基础结构,详细的网络结构可以查看`network/Pix2pix_network.py`文件。生成网络提供两种可选的网络结构:Unet网络结构和普通的encoder-decoder网络结构。网络利用损失函数学习从输入图像到输出图像的映射,生成网络损失函数由GAN的损失函数和L1损失函数组成,判别网络损失函数由GAN的损失函数组成。生成器的网络结构如下图所示:
<p align="center">
<img src="images/pix2pix_gen.png" width="550"/><br />
......@@ -189,7 +191,7 @@ Pix2Pix生成网络结构图[5]
</p>
- CycleGAN由两个生成网络和两个判别网络组成,生成网络A是输入A类风格的图片输出B类风格的图片,生成网络B是输入B类风格的图片输出A类风格的图片。生成网络中编码部分的网络结构都是采用`convolution-norm-ReLU`作为基础结构,解码部分的网络结构由`transpose convolution-norm-ReLU`组成,判别网络基本是由`convolution-norm-leaky_ReLU`作为基础结构,详细的网络结构可以查看`network/CycleGAN_network.py`文件。生成网络提供两种可选的网络结构:Unet网络结构和普通的encoder-decoder网络结构。生成网络损失函数由CGAN的损失函数,重构损失和自身损失组成,判别网络的损失函数由CGAN的损失函数组成。
- CycleGAN由两个生成网络和两个判别网络组成,生成网络A是输入A类风格的图片输出B类风格的图片,生成网络B是输入B类风格的图片输出A类风格的图片。生成网络中编码部分的网络结构都是采用`convolution-norm-ReLU`作为基础结构,解码部分的网络结构由`transpose convolution-norm-ReLU`组成,判别网络基本是由`convolution-norm-leaky_ReLU`作为基础结构,详细的网络结构可以查看`network/CycleGAN_network.py`文件。生成网络提供两种可选的网络结构:Unet网络结构和普通的encoder-decoder网络结构。生成网络损失函数由LSGAN的损失函数,重构损失和自身损失组成,判别网络的损失函数由LSGAN的损失函数组成。
<p align="center">
<img src="images/pix2pix_gen.png" width="550"/><br />
......@@ -197,7 +199,7 @@ CycleGAN生成网络结构图[5]
</p>
- StarGAN中生成网络的编码部分主要由`convolution-instance norm-ReLU`组成,解码部分主要由`transpose convolution-norm-ReLU`组成,判别网络主要由`convolution-leaky_ReLU`组成,详细网络结构可以查看`network/StarGAN_network.py`文件。生成网络的损失函数是由CGAN的损失函数,重构损失和分类损失组成,判别网络的损失函数由预测损失,分类损失和梯度惩罚损失组成。
- StarGAN中生成网络的编码部分主要由`convolution-instance norm-ReLU`组成,解码部分主要由`transpose convolution-norm-ReLU`组成,判别网络主要由`convolution-leaky_ReLU`组成,详细网络结构可以查看`network/StarGAN_network.py`文件。生成网络的损失函数是由WGAN的损失函数,重构损失和分类损失组成,判别网络的损失函数由预测损失,分类损失和梯度惩罚损失组成。
<p align="center">
<img src="images/stargan_gen.png" width=350 />
......@@ -207,7 +209,7 @@ StarGAN的生成网络结构[左]和判别网络结构[右] [7]
- AttGAN中生成网络的编码部分主要由`convolution-instance norm-ReLU`组成,解码部分由`transpose convolution-norm-ReLU`组成,判别网络主要由`convolution-leaky_ReLU`组成,详细网络结构可以查看`network/AttGAN_network.py`文件。生成网络的损失函数是由CGAN的损失函数,重构损失和分类损失组成,判别网络的损失函数由预测损失,分类损失和梯度惩罚损失组成。
- AttGAN中生成网络的编码部分主要由`convolution-instance norm-ReLU`组成,解码部分由`transpose convolution-norm-ReLU`组成,判别网络主要由`convolution-leaky_ReLU`组成,详细网络结构可以查看`network/AttGAN_network.py`文件。生成网络的损失函数是由WGAN的损失函数,重构损失和分类损失组成,判别网络的损失函数由预测损失,分类损失和梯度惩罚损失组成。
<p align="center">
<img src="images/attgan_net.png" width=800 /> <br />
......@@ -215,7 +217,7 @@ AttGAN的网络结构[8]
</p>
- STGAN中生成网络再编码器和解码器之间加入Selective Transfer Units\(STU\),有选择的转换编码网络,从而更好的适配解码网络。生成网络中的编码网络主要由`convolution-instance norm-ReLU`组成,解码网络主要由`transpose convolution-norm-leaky_ReLU`组成,判别网络主要由`convolution-leaky_ReLU`组成,详细网络结构可以查看`network/STGAN_network.py`文件。生成网络的损失函数是由CGAN的损失函数,重构损失和分类损失组成,判别网络的损失函数由预测损失,分类损失和梯度惩罚损失组成。
- STGAN中生成网络再编码器和解码器之间加入Selective Transfer Units\(STU\),有选择的转换编码网络,从而更好的适配解码网络。生成网络中的编码网络主要由`convolution-instance norm-ReLU`组成,解码网络主要由`transpose convolution-norm-leaky_ReLU`组成,判别网络主要由`convolution-leaky_ReLU`组成,详细网络结构可以查看`network/STGAN_network.py`文件。生成网络的损失函数是由WGAN的损失函数,重构损失和分类损失组成,判别网络的损失函数由预测损失,分类损失和梯度惩罚损失组成。
<p align="center">
<img src="images/stgan_net.png" width=800 /> <br />
......@@ -228,12 +230,12 @@ STGAN的网络结构[9]
## FAQ
**Q:** StarGAN/AttGAN/STGAN中属性没有变化,为什么?
**Q:** StarGAN/AttGAN/STGAN中属性没有变化,为什么?
**A:** 查看是否所有的标签都转换对了。
**Q:** 预测结果不正常,是怎么回事?
**A:** 某些GAN预测的时候batch_norm的设置需要和训练的时候行为一致,查看模型库中相应的GAN中预测时batch_norm的行为和自己模型中的预测时batch_norm的
行为是否一致。
行为是否一致。
**Q:** 为什么STGAN和ATTGAN中变男性得到的预测结果是变女性呢?
**A:** 这是由于预测时标签的设置,目标标签是基于原本的标签进行改变,比如原本图片是男生,预测代码对标签进行转变的时候会自动变成相对立的标签,即女
......
......@@ -118,7 +118,6 @@ def infer(args):
print(args.init_model + '/' + model_name)
fluid.io.load_persistables(exe, args.init_model + "/" + model_name)
print('load params done')
if not os.path.exists(args.output):
os.makedirs(args.output)
......@@ -144,7 +143,6 @@ def infer(args):
tensor_label_trg_ = fluid.LoDTensor()
tensor_img.set(real_img, place)
tensor_label_org.set(label_org, place)
real_img_temp = save_batch_image(real_img)
images = [real_img_temp]
for i in range(args.c_dim):
......@@ -153,11 +151,13 @@ def infer(args):
label_trg_tmp[j][i] = 1.0 - label_trg_tmp[j][i]
label_trg_tmp = check_attribute_conflict(
label_trg_tmp, attr_names[i], attr_names)
label_org_ = list(map(lambda x: ((x * 2) - 1) * 0.5, label_org))
label_trg_ = list(
map(lambda x: ((x * 2) - 1) * 0.5, label_trg_tmp))
for j in range(len(label_org)):
label_trg_[j][i] = label_trg_[j][i] * 2.0
tensor_label_org_.set(label_org, place)
if args.model_net == 'AttGAN':
for j in range(len(label_org)):
label_trg_[j][i] = label_trg_[j][i] * 2.0
tensor_label_org_.set(label_org_, place)
tensor_label_trg.set(label_trg, place)
tensor_label_trg_.set(label_trg_, place)
out = exe.run(feed={
......@@ -189,7 +189,6 @@ def infer(args):
tensor_label_org = fluid.LoDTensor()
tensor_img.set(real_img, place)
tensor_label_org.set(label_org, place)
real_img_temp = save_batch_image(real_img)
images = [real_img_temp]
for i in range(args.c_dim):
......
......@@ -55,6 +55,7 @@ class AttGAN_model(object):
name=name,
dim=cfg.d_base_dims,
fc_dim=cfg.d_fc_dim,
norm=cfg.dis_norm,
n_layers=cfg.n_layers)
def concat(self, z, a):
......@@ -149,11 +150,11 @@ class AttGAN_model(object):
d,
4,
2,
norm=None,
norm=norm,
padding=1,
activation_fn='leaky_relu',
name=name + str(i),
use_bias=True,
use_bias=(norm == None),
relufactor=0.01,
initial='kaiming')
......
......@@ -76,6 +76,7 @@ class STGAN_model(object):
n_atts=cfg.c_dim,
dim=cfg.d_base_dims,
fc_dim=cfg.d_fc_dim,
norm=cfg.dis_norm,
n_layers=cfg.n_layers,
name=name)
......@@ -100,7 +101,7 @@ class STGAN_model(object):
activation_fn='leaky_relu',
name=name + str(i),
use_bias=False,
relufactor=0.01,
relufactor=0.2,
initial='kaiming',
is_test=is_test)
zs.append(z)
......@@ -132,7 +133,7 @@ class STGAN_model(object):
pass_state=pass_state,
name=name + str(i),
is_test=is_test)
zs_.insert(0, output[0] + zs[n_layers - 1 - i])
zs_.insert(0, output[0])
if inject_layers > i:
state = self.concat(output[1], a)
else:
......@@ -202,18 +203,18 @@ class STGAN_model(object):
d,
4,
2,
norm=None,
padding=1,
norm=norm,
padding_type="SAME",
activation_fn='leaky_relu',
name=name + str(i),
use_bias=True,
relufactor=0.01,
use_bias=(norm == None),
relufactor=0.2,
initial='kaiming')
logit_gan = linear(
y,
fc_dim,
activation_fn='relu',
activation_fn='leaky_relu',
name=name + 'fc_adv_1',
initial='kaiming')
logit_gan = linear(
......@@ -222,7 +223,7 @@ class STGAN_model(object):
logit_att = linear(
y,
fc_dim,
activation_fn='relu',
activation_fn='leaky_relu',
name=name + 'fc_cls_1',
initial='kaiming')
logit_att = linear(
......
......@@ -76,7 +76,7 @@ def norm_layer(input, norm_type='batch_norm', name=None, is_test=False):
tmp = fluid.layers.elementwise_add(tmp, offset, axis=1)
return tmp
else:
raise NotImplementedError("norm tyoe: [%s] is not support" % norm_type)
raise NotImplementedError("norm type: [%s] is not support" % norm_type)
def initial_type(name,
......
python train.py --model_net AttGAN --dataset celeba --crop_size 170 --image_size 128 --train_list ./data/celeba/list_attr_celeba.txt --test_list ./data/celeba/test_list_attr_celeba.txt --gan_mode wgan --batch_size 32 --print_freq 1 --num_discriminator_time 5 --epoch 90 >log.out 2>log_err
python train.py --model_net AttGAN --dataset celeba --crop_size 170 --image_size 128 --train_list ./data/celeba/list_attr_celeba.txt --test_list ./data/celeba/test_list_attr_celeba.txt --gan_mode wgan --batch_size 32 --print_freq 1 --num_discriminator_time 5 --epoch 90 >log_out 2>log_err
python train.py --model_net STGAN --dataset celeba --crop_size 170 --image_size 128 --train_list ./data/celeba/list_attr_celeba.txt --test_list ./data/celeba/test_list_attr_celeba.txt --gan_mode wgan --batch_size 32 --print_freq 1 --num_discriminator_time 5 --epoch 20 >log.out 2>log_err
python train.py --model_net STGAN --dataset celeba --crop_size 170 --image_size 128 --train_list ./data/celeba/list_attr_celeba.txt --test_list ./data/celeba/test_list_attr_celeba.txt --gan_mode wgan --batch_size 32 --print_freq 1 --num_discriminator_time 5 --epoch 50 >log_out 2>log_err
......@@ -229,6 +229,12 @@ class AttGAN(object):
type=int,
default=5,
help="default layers in the network")
parser.add_argument(
'--dis_norm',
type=str,
default=None,
help="the normalization in discriminator, choose in [None, instance_norm]"
)
return parser
......
......@@ -79,7 +79,7 @@ class DTrainer():
clone_image_real = b.var('image_real')
break
self.fake_img, _ = model.network_G(
image_real, label_org, label_trg_, cfg, name="generator")
image_real, label_org_, label_trg_, cfg, name="generator")
self.pred_real, self.cls_real = model.network_D(
image_real, cfg, name="discriminator")
self.pred_real.persistable = True
......@@ -234,6 +234,12 @@ class STGAN(object):
type=int,
default=4,
help="default layers of GRU in generotor")
parser.add_argument(
'--dis_norm',
type=str,
default=None,
help="the normalization in discriminator, choose in [None, instance_norm]"
)
return parser
......
......@@ -109,7 +109,7 @@ def save_test_image(epoch,
tensor_label_org = fluid.LoDTensor()
tensor_img.set(real_img, place)
tensor_label_org.set(label_org, place)
real_img_temp = np.squeeze(real_img).transpose([1, 2, 0])
real_img_temp = save_batch_image(real_img)
images = [real_img_temp]
for i in range(cfg.c_dim):
label_trg_tmp = copy.deepcopy(label_org)
......@@ -126,8 +126,8 @@ def save_test_image(epoch,
"label_trg": tensor_label_trg
},
fetch_list=[g_trainer.fake_img, g_trainer.rec_img])
fake_temp = np.squeeze(fake_temp[0]).transpose([1, 2, 0])
rec_temp = np.squeeze(rec_temp[0]).transpose([1, 2, 0])
fake_temp = save_batch_image(fake_temp[0])
rec_temp = save_batch_image(rec_temp[0])
images.append(fake_temp)
images.append(rec_temp)
images_concat = np.concatenate(images, 1)
......@@ -145,7 +145,7 @@ def save_test_image(epoch,
tensor_label_trg_ = fluid.LoDTensor()
tensor_img.set(real_img, place)
tensor_label_org.set(label_org, place)
real_img_temp = np.squeeze(real_img).transpose([0, 2, 3, 1])
real_img_temp = save_batch_image(real_img)
images = [real_img_temp]
for i in range(cfg.c_dim):
label_trg_tmp = copy.deepcopy(label_trg)
......@@ -155,12 +155,14 @@ def save_test_image(epoch,
label_trg_tmp = check_attribute_conflict(
label_trg_tmp, attr_names[i], attr_names)
label_org_ = list(map(lambda x: ((x * 2) - 1) * 0.5, label_org))
label_trg_ = list(
map(lambda x: ((x * 2) - 1) * 0.5, label_trg_tmp))
for j in range(len(label_org)):
label_trg_[j][i] = label_trg_[j][i] * 2.0
tensor_label_org_.set(label_org, place)
if cfg.model_net == 'AttGAN':
for j in range(len(label_org)):
label_trg_[j][i] = label_trg_[j][i] * 2.0
tensor_label_org_.set(label_org_, place)
tensor_label_trg.set(label_trg, place)
tensor_label_trg_.set(label_trg_, place)
out = exe.run(test_program,
......@@ -172,10 +174,11 @@ def save_test_image(epoch,
"label_trg_": tensor_label_trg_
},
fetch_list=[g_trainer.fake_img])
fake_temp = np.squeeze(out[0]).transpose([0, 2, 3, 1])
fake_temp = save_batch_image(out[0])
images.append(fake_temp)
images_concat = np.concatenate(images, 1)
images_concat = np.concatenate(images_concat, 1)
if len(label_org) > 1:
images_concat = np.concatenate(images_concat, 1)
imageio.imwrite(out_path + "/fake_img" + str(epoch) + '_' + name[0],
((images_concat + 1) * 127.5).astype(np.uint8))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册