未验证 提交 1c13f704 编写于 作者: L lvmengsi 提交者: GitHub

fix some bug of gan (#2490)

* fix some bug of gan
上级 2df1c9c3
#图像生成模型库
# 图像生成模型库
生成对抗网络(Generative Adversarial Network\[[1](#参考文献)\], 简称GAN) 是一种非监督学习的方式,通过让两个神经网络相互博弈的方法进行学习,该方法由lan Goodfellow等人在2014年提出。生成对抗网络由一个生成网络和一个判别网络组成,生成网络从潜在的空间(latent space)中随机采样作为输入,其输出结果需要尽量模仿训练集中的真实样本。判别网络的输入为真实样本或生成网络的输出,其目的是将生成网络的输出从真实样本中尽可能的分辨出来。而生成网络则尽可能的欺骗判别网络,两个网络相互对抗,不断调整参数。
生成对抗网络常用于生成以假乱真的图片。此外,该方法还被用于生成影片,三维物体模型等。\[[2](#参考文献)\]
---
##内容
## 内容
-[简介](#简介)
-[快速开始](#快速开始)
-[参考文献](#参考文献)
##简介
## 简介
本图像生成模型库包含CGAN\[[3](#参考文献)\], DCGAN\[[4](#参考文献)\], Pix2Pix\[[5](#参考文献)\], CycleGAN\[[6](#参考文献)\], StarGAN\[[7](#参考文献)\], AttGAN\[[8](#参考文献)\], STGAN\[[9](#参考文献)\]
......@@ -54,12 +54,12 @@
```
##快速开始
## 快速开始
**安装[PaddlePaddle](https://github.com/PaddlePaddle/Paddle):**
在当前目录下运行样例代码需要PadddlePaddle Fluid的v.1.5或以上的版本。如果你的运行环境中的PaddlePaddle低于此版本,请根据[安装文档](http://paddlepaddle.org/documentation/docs/zh/1.4/beginners_guide/install/index_cn.html)中的说明来更新PaddlePaddle。
###数据准备
### 数据准备
模型库中提供了download.py数据下载脚本,该脚本支持下载MNIST数据集,CycleGAN和Pix2Pix所需要的数据集。使用以下命令下载数据:
python download.py --dataset=mnist
......@@ -70,12 +70,12 @@ StarGAN, AttGAN和STGAN所需要的[Celeba](http://mmlab.ie.cuhk.edu.hk/projects
**自定义数据集:**
用户可以使用自定义的数据集,只要设置成所对应的生成模型所需要的数据格式即可。
ps: pix2pix模型数据集准备中的list文件需要通过scripts文件夹里的make_pair_data.py来生成,可以使用以下命令来生成:
注意: pix2pix模型数据集准备中的list文件需要通过scripts文件夹里的make_pair_data.py来生成,可以使用以下命令来生成:
python scripts/make_pair_data.py \
--direction=A2B
用户可以通过指定direction参数生成list文件,从而确保图像风格转变的方向。
###模型训练
### 模型训练
**开始训练:** 数据准备完毕后,可以通过一下方式启动训练:
python train.py \
--model_net=$(name_of_model) \
......@@ -87,7 +87,7 @@ ps: pix2pix模型数据集准备中的list文件需要通过scripts文件夹里
用户可以通过设置model_net参数来选择想要训练的模型,通过设置dataset参数来选择训练所需要的数据集。
###模型测试
### 模型测试
模型测试是利用训练完成的生成模型进行图像生成。infer.py是主要的执行程序,调用示例如下:
python infer.py \
--model_net=$(name_of_model) \
......@@ -95,13 +95,21 @@ ps: pix2pix模型数据集准备中的list文件需要通过scripts文件夹里
--dataset_dir=$(path_to_data)
##参考文献
## 参考文献
[1] [Goodfellow, Ian J.; Pouget-Abadie, Jean; Mirza, Mehdi; Xu, Bing; Warde-Farley, David; Ozair, Sherjil; Courville, Aaron; Bengio, Yoshua. Generative Adversarial Networks. 2014. arXiv:1406.2661 [stat.ML].](https://arxiv.org/abs/1406.2661)
[2] [https://zh.wikipedia.org/wiki/生成对抗网络](https://zh.wikipedia.org/wiki/生成对抗网络)
[3] [Conditional Generative Adversarial Nets](https://arxiv.org/abs/1411.1784)
[4] [Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks](https://arxiv.org/abs/1511.06434)
[5] [Image-to-Image Translation with Conditional Adversarial Networks](https://arxiv.org/abs/1611.07004)
[6] [Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks](https://arxiv.org/abs/1703.10593)
[7] [StarGAN: Unified Generative Adversarial Networks for Multi-Domain Image-to-Image Translation](https://arxiv.org/abs/1711.09020)
[8] [AttGAN: Facial Attribute Editing by Only Changing What You Want](https://arxiv.org/abs/1711.10678)
[9] [STGAN: A Unified Selective Transfer Network for Arbitrary Image Attribute Editing](https://arxiv.org/abs/1904.09709)
......@@ -262,7 +262,7 @@ class celeba_reader_creator(reader_creator):
attr_names = args.selected_attrs.split(',')
for line in lines:
arr = line.strip().split()
name = './images/' + arr[0]
name = os.path.join('img_align_celeba', arr[0])
label = []
for attr_name in attr_names:
idx = attr2idx[attr_name]
......@@ -318,19 +318,11 @@ class celeba_reader_creator(reader_creator):
batch_out_2 = []
batch_out_3 = []
for file, label in self.images:
if args.model_net == 'StarGAN':
img = Image.open(os.path.join(self.image_dir, file))
label = np.array(label).astype("float32")
img = CentorCrop(img, args.crop_size, args.crop_size)
img = img.resize((args.image_size, args.image_size),
Image.BILINEAR)
else:
img = Image.open(os.path.join(self.image_dir,
file)).convert('RGB')
label = np.array(label).astype("float32")
img = CentorCrop(img, 170, 170)
img = img.resize((args.image_size, args.image_size),
Image.BILINEAR)
img = Image.open(os.path.join(self.image_dir, file))
label = np.array(label).astype("float32")
img = CentorCrop(img, args.crop_size, args.crop_size)
img = img.resize((args.image_size, args.image_size),
Image.BILINEAR)
img = (np.array(img).astype('float32') / 255.0 - 0.5) / 0.5
img = img.transpose([2, 0, 1])
if return_name:
......@@ -523,3 +515,22 @@ class data_reader(object):
reader = train_reader.get_train_reader(
self.cfg, shuffle=self.shuffle)
return reader, reader_test, batch_num
else:
dataset_dir = os.path.join(self.cfg.data_dir, self.cfg.dataset)
train_list = os.path.join(dataset_dir, 'train.txt')
if self.cfg.train_list is not None:
train_list = self.cfg.train_list
train_reader = reader_creator(
image_dir=dataset_dir, list_filename=train_list)
reader_test = None
if self.cfg.run_test:
test_list = os.path.join(dataset_dir, "test.txt")
test_reader = reader_creator(
image_dir=dataset_dir,
list_filename=test_list,
batch_size=1,
drop_last=self.cfg.drop_last)
reader_test = test_reader.get_test_reader(
self.cfg, shuffle=False, return_name=True)
batch_num = train_reader.len()
return train_reader, reader_test, batch_num
......@@ -60,7 +60,7 @@ class StarGAN_model(object):
input1 = fluid.layers.concat([input, label_trg_e], 1)
conv0 = conv2d(
input1,
cfg.g_conv_dim,
cfg.g_base_dims,
7,
1,
padding=3,
......@@ -74,7 +74,7 @@ class StarGAN_model(object):
rate = 2**(i + 1)
conv_down = conv2d(
conv_down,
cfg.g_conv_dim * rate,
cfg.g_base_dims * rate,
4,
2,
padding=1,
......@@ -86,13 +86,15 @@ class StarGAN_model(object):
res_block = conv_down
for i in range(repeat_num):
res_block = self.ResidualBlock(
res_block, cfg.g_conv_dim * (2**2), name=name + '.%d' % (i + 9))
res_block,
cfg.g_base_dims * (2**2),
name=name + '.%d' % (i + 9))
deconv = res_block
for i in range(2):
rate = 2**(1 - i)
deconv = deconv2d(
deconv,
cfg.g_conv_dim * rate,
cfg.g_base_dims * rate,
4,
2,
padding=1,
......@@ -117,7 +119,7 @@ class StarGAN_model(object):
def network_D(self, input, cfg, name="discriminator"):
conv0 = conv2d(
input,
cfg.d_conv_dim,
cfg.d_base_dims,
4,
2,
padding=1,
......@@ -125,7 +127,7 @@ class StarGAN_model(object):
name=name + '0',
initial='kaiming')
repeat_num = 6
curr_dim = cfg.d_conv_dim
curr_dim = cfg.d_base_dims
conv = conv0
for i in range(1, repeat_num):
curr_dim *= 2
......
CUDA_VISIBLE_DEVICES=2 python train.py --model_net StarGAN --dataset celeba --crop_size 178 --image_size 128 --train_list ./data/celeba/list_attr_celeba.txt --test_list ./data/celeba/test_list_attr_celeba.txt --gan_mode wgan --batch_size 16 --epoch 200 > log_out 2>log_err
#CUDA_VISIBLE_DEVICES=0 python train.py --model_net StarGAN --dataset celeba --crop_size 178 --image_size 128 --train_list ./test_list --test_list ./data/celeba/test_list_attr_celeba.txt --gan_mode wgan --batch_size 2 --epoch 200 > log_out 2>log_err
python train.py --model_net StarGAN --dataset celeba --crop_size 178 --image_size 128 --train_list ./data/celeba/list_attr_celeba.txt --test_list ./data/celeba/test_list_attr_celeba.txt --gan_mode wgan --batch_size 16 --epoch 200 > log_out 2>log_err
......@@ -80,8 +80,6 @@ class GTrainer():
beta2=0.999,
name="net_G")
optimizer.minimize(self.g_loss, parameter_list=vars)
with open('program_gen.txt', 'w') as f:
print(self.program, file=f)
class DTrainer():
......@@ -164,8 +162,6 @@ class DTrainer():
name="net_D")
optimizer.minimize(self.d_loss, parameter_list=vars)
with open('program_dis.txt', 'w') as f:
print(self.program, file=f)
def gradient_penalty(self, f, real, fake, cfg=None, name=None):
def _interpolate(a, b):
......@@ -213,16 +209,6 @@ class StarGAN(object):
type=int,
default=5,
help="the number of attributes we selected")
parser.add_argument(
'--g_conv_dim',
type=int,
default=64,
help="base conv dims in generator")
parser.add_argument(
'--d_conv_dim',
type=int,
default=64,
help="base conv dims in discriminator")
parser.add_argument(
'--g_repeat_num',
type=int,
......@@ -304,9 +290,6 @@ class StarGAN(object):
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
with open('program.txt', "w") as f:
print(gen_trainer.program, file=f)
if self.cfg.init_model:
utility.init_checkpoints(self.cfg, exe, gen_trainer, "net_G")
utility.init_checkpoints(self.cfg, exe, dis_trainer, "net_D")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册