README.md 13.7 KB
Newer Older
L
lvmengsi 已提交
1
# 图像生成模型库
L
lvmengsi 已提交
2 3 4 5 6

生成对抗网络(Generative Adversarial Network\[[1](#参考文献)\], 简称GAN) 是一种非监督学习的方式,通过让两个神经网络相互博弈的方法进行学习,该方法由lan Goodfellow等人在2014年提出。生成对抗网络由一个生成网络和一个判别网络组成,生成网络从潜在的空间(latent space)中随机采样作为输入,其输出结果需要尽量模仿训练集中的真实样本。判别网络的输入为真实样本或生成网络的输出,其目的是将生成网络的输出从真实样本中尽可能的分辨出来。而生成网络则尽可能的欺骗判别网络,两个网络相互对抗,不断调整参数。
生成对抗网络常用于生成以假乱真的图片。此外,该方法还被用于生成影片,三维物体模型等。\[[2](#参考文献)\]

---
L
lvmengsi 已提交
7
## 内容
L
lvmengsi 已提交
8

L
lvmengsi 已提交
9 10 11 12 13 14 15
- [模型简介](#模型简介)
- [快速开始](#快速开始)
- [进阶使用](#进阶使用)
- [FAQ](#faq)
- [参考论文](#参考论文)
- [版本更新](#版本更新)
- [作者](#作者)
L
lvmengsi 已提交
16

L
lvmengsi 已提交
17
## 模型简介
L
lvmengsi 已提交
18 19 20

本图像生成模型库包含CGAN\[[3](#参考文献)\], DCGAN\[[4](#参考文献)\], Pix2Pix\[[5](#参考文献)\], CycleGAN\[[6](#参考文献)\], StarGAN\[[7](#参考文献)\], AttGAN\[[8](#参考文献)\], STGAN\[[9](#参考文献)\]

L
lvmengsi 已提交
21 22
注意:AttGAN和STGAN的网络结构中,判别器去掉了instance norm。

L
lvmengsi 已提交
23 24 25 26 27 28 29 30 31 32 33 34
图像生成模型库库的目录结构如下:
```
├── download.py 下载数据

├── data_reader.py 数据预处理

├── train.py 模型的训练入口

├── infer.py 模型的预测入口

├── trainer 不同模型的训练脚本
│   ├── CGAN.py Conditional GAN的训练脚本
L
lvmengsi 已提交
35
│   ├── ...
L
lvmengsi 已提交
36 37 38 39
│   ├── STGAN.py STGAN的训练脚本

├── network 不同模型的网络结构
│   ├── base_network.py GAN模型需要的公共基础网络结构
L
lvmengsi 已提交
40
│   ├── ...
L
lvmengsi 已提交
41 42 43 44 45 46 47 48 49 50 51 52 53
│   ├── STGAN_network.py STGAN的网络结构

├── util 网络的基础配置和公共模块
│   ├── config.py 网络公用的基础配置
│   ├── utility.py 保存模型等网络公用的模块

├── scripts 多个模型的训练启动和测试启动示例
│   ├── run_....py 训练启动示例
│   ├── infer_....py 测试启动示例
│   ├── make_pair_data.py pix2pix GAN的数据list的生成脚本

```

L
lvmengsi 已提交
54
## 快速开始
L
lvmengsi 已提交
55 56

### 安装说明
L
lvmengsi 已提交
57 58 59 60
**安装[PaddlePaddle](https://github.com/PaddlePaddle/Paddle):**

在当前目录下运行样例代码需要PadddlePaddle Fluid的v.1.5或以上的版本。如果你的运行环境中的PaddlePaddle低于此版本,请根据[安装文档](http://paddlepaddle.org/documentation/docs/zh/1.4/beginners_guide/install/index_cn.html)中的说明来更新PaddlePaddle。

L
lvmengsi 已提交
61 62 63 64 65
### 任务简介

Pix2Pix和CycleGAN采用cityscapes\[[10](#参考文献)\]数据集进行风格转换。
StarGAN,AttGAN和STGAN采用celeba\[[11](#参考文献)\]数据集进行属性迁移。

L
lvmengsi 已提交
66
### 数据准备
L
lvmengsi 已提交
67 68 69 70 71 72 73 74 75 76

模型库中提供了download.py数据下载脚本,该脚本支持下载MNIST数据集,CycleGAN和Pix2Pix所需要的数据集。使用以下命令下载数据:
    python download.py --dataset=mnist
通过指定dataset参数来下载相应的数据集。

StarGAN, AttGAN和STGAN所需要的[Celeba](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html)数据集可以自行下载。

**自定义数据集:**
用户可以使用自定义的数据集,只要设置成所对应的生成模型所需要的数据格式即可。

L
lvmengsi 已提交
77
注意: pix2pix模型数据集准备中的list文件需要通过scripts文件夹里的make_pair_data.py来生成,可以使用以下命令来生成:
L
lvmengsi 已提交
78 79
  python scripts/make_pair_data.py \
    --direction=A2B
L
lvmengsi 已提交
80
用户可以通过设置`--direction`参数生成list文件,从而确保图像风格转变的方向。
L
lvmengsi 已提交
81

L
lvmengsi 已提交
82
### 模型训练
L
lvmengsi 已提交
83

L
lvmengsi 已提交
84
**开始训练:** 数据准备完毕后,可以通过一下方式启动训练:
L
lvmengsi 已提交
85

L
lvmengsi 已提交
86 87 88 89 90 91 92
    python train.py \
      --model_net=$(name_of_model) \
      --dataset=$(name_of_dataset) \
      --data_dir=$(path_to_data) \
      --train_list=$(path_to_train_data_list) \
      --test_list=$(path_to_test_data_list) \
      --batch_size=$(batch_size)
L
lvmengsi 已提交
93

L
lvmengsi 已提交
94 95 96 97 98 99 100 101 102
- 可选参数见:

    python train.py --help

- 每个GAN都给出了一份运行示例,放在scripts文件夹内,用户可以直接运行训练脚本快速开始训练。
- 用户可以通过设置`--model_net`参数来选择想要训练的模型,通过设置`--dataset`参数来选择训练所需要的数据集。

### 模型测试
模型测试是利用训练完成的生成模型进行图像生成。infer.py是主要的执行程序,调用示例如下:
L
lvmengsi 已提交
103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123

执行以下命令得到CyleGAN的预测结果:

    python infer.py \
      --model_net=CycleGAN \
      --init_model=$(path_to_init_model) \
      --image_size=256 \
      --dataset_dir=$(path_to_data) \
      --input_style=$(A_or_B) \
      --net_G=$(generator_network) \
      --g_base_dims=$(base_dim_of_generator)

执行以下命令得到Pix2Pix的预测结果:

    python infer.py \
      --model_net=Pix2pix \
      --init_model=$(path_to_init_model) \
      --image_size=256 \
      --dataset_dir=$(path_to_data) \
      --net_G=$(generator_network)

L
lvmengsi 已提交
124
执行以下命令得到StarGAN,AttGAN或者STGAN的预测结果:
L
lvmengsi 已提交
125 126 127 128 129 130

    python infer.py \
      --model_net=$(StarGAN_or_AttGAN_or_STGAN) \
      --init_model=$(path_to_init_model)\
      --dataset_dir=$(path_to_data)

L
lvmengsi 已提交
131
Pix2Pix和CycleGAN的效果如图所示:
L
lvmengsi 已提交
132

L
lvmengsi 已提交
133 134 135
<p align="center">
<img src="images/pix2pix_cyclegan.png" width="650"/><br />
Pix2Pix和CycleGAN的效果图
L
lvmengsi 已提交
136
</p>
L
lvmengsi 已提交
137

L
lvmengsi 已提交
138

L
lvmengsi 已提交
139
StarGAN,AttGAN和STGAN的效果如图所示:
L
lvmengsi 已提交
140

L
lvmengsi 已提交
141 142 143
<p align="center">
<img src="images/female_stargan_attgan_stgan.png" width="650"/><br />
StarGAN,AttGAN和STGAN的效果图
L
lvmengsi 已提交
144
</p>
L
lvmengsi 已提交
145

L
lvmengsi 已提交
146

L
lvmengsi 已提交
147 148

- 每个GAN都给出了一份测试示例,放在scripts文件夹内,用户可以直接运行测试脚本得到测试结果。
L
lvmengsi 已提交
149

L
lvmengsi 已提交
150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183
**下载预训练模型:**

本示例提供以下预训练模型:

| Model| Pretrained model |
|:---|:---|
| Pix2Pix  | [Pix2Pix的预训练模型](https://paddle-gan-models.bj.bcebos.com/pix2pix_G.tar.gz)  |
| CycleGAN | [CycleGAN的预训练模型](https://paddle-gan-models.bj.bcebos.com/cyclegan_9blocks_G.tar.gz) |
| StarGAN  | [StarGAN的预训练模型](https://paddle-gan-models.bj.bcebos.com/stargan_G.tar.gz)  |
| AttGAN   | [AttGAN的预训练模型]()   |
| STGAN    | [STGAN的预训练模型](https://paddle-gan-models.bj.bcebos.com/stgan_G.tar.gz)    |


## 进阶使用
### 背景介绍

CGAN,条件生成对抗网络,一种带条件约束的GAN,使用额外信息对模型增加条件,可以指导数据生成过程。

DCGAN,深度卷积生成对抗网络,将GAN和卷积网络结合起来,以解决GAN训练不稳定的问题,利用卷积神经网络作为网络结构进行图像生成,可以得到更加丰富的层次表达。

Pix2Pix利用成对的图片进行图像翻译,即输入为同一张图片的两种不同风格,可用于进行风格迁移。

CycleGAN可以利用非成对的图片进行图像翻译,即输入为两种不同风格的不同图片,自动进行风格转换。

StarGAN多领域属性迁移,引入辅助分类帮助单个判别器判断多个属性,可用于人脸属性转换。

AttGAN利用分类损失和重构损失来保证改变特定的属性,可用于人脸特定属性转换。

STGAN只输入有变化的标签,引入GRU结构,更好的选择变化的属性,可用于人脸特定属性转换。

### 模型概览

- Pix2Pix由一个生成网络和一个判别网络组成。生成网络中编码部分的网络结构都是采用`convolution-batch norm-ReLU`作为基础结构,解码部分的网络结构由`transpose convolution-batch norm-ReLU`组成,判别网络基本是由`convolution-norm-leaky_ReLU`作为基础结构,详细的网络结构可以查看`network/Pix2pix_network.py`文件。生成网络提供两种可选的网络结构:Unet网络结构和普通的encoder-decoder网络结构。网络利用损失函数学习从输入图像到输出图像的映射,生成网络损失函数由CGAN的损失函数和L1损失函数组成,判别网络损失函数由CGAN的损失函数组成。生成器的网络结构如下图所示:

L
lvmengsi 已提交
184 185 186
<p align="center">
<img src="images/pix2pix_gen.png" width="550"/><br />
Pix2Pix生成网络结构图[5]
L
lvmengsi 已提交
187 188 189 190 191
</p>


- CycleGAN由两个生成网络和两个判别网络组成,生成网络A是输入A类风格的图片输出B类风格的图片,生成网络B是输入B类风格的图片输出A类风格的图片。生成网络中编码部分的网络结构都是采用`convolution-norm-ReLU`作为基础结构,解码部分的网络结构由`transpose convolution-norm-ReLU`组成,判别网络基本是由`convolution-norm-leaky_ReLU`作为基础结构,详细的网络结构可以查看`network/CycleGAN_network.py`文件。生成网络提供两种可选的网络结构:Unet网络结构和普通的encoder-decoder网络结构。生成网络损失函数由CGAN的损失函数,重构损失和自身损失组成,判别网络的损失函数由CGAN的损失函数组成。

L
lvmengsi 已提交
192 193 194
<p align="center">
<img src="images/pix2pix_gen.png" width="550"/><br />
CycleGAN生成网络结构图[5]
L
lvmengsi 已提交
195 196 197 198 199
</p>


- StarGAN中生成网络的编码部分主要由`convolution-instance norm-ReLU`组成,解码部分主要由`transpose convolution-norm-ReLU`组成,判别网络主要由`convolution-leaky_ReLU`组成,详细网络结构可以查看`network/StarGAN_network.py`文件。生成网络的损失函数是由CGAN的损失函数,重构损失和分类损失组成,判别网络的损失函数由预测损失,分类损失和梯度惩罚损失组成。

L
lvmengsi 已提交
200 201 202 203
<p align="center">
<img src="images/stargan_gen.png" width=350 />
<img src="images/stargan_dis.png" width=400 /> <br />
StarGAN的生成网络结构[左]和判别网络结构[右] [7]
L
lvmengsi 已提交
204 205 206
</p>


L
lvmengsi 已提交
207

L
lvmengsi 已提交
208 209
- AttGAN中生成网络的编码部分主要由`convolution-instance norm-ReLU`组成,解码部分由`transpose convolution-norm-ReLU`组成,判别网络主要由`convolution-leaky_ReLU`组成,详细网络结构可以查看`network/AttGAN_network.py`文件。生成网络的损失函数是由CGAN的损失函数,重构损失和分类损失组成,判别网络的损失函数由预测损失,分类损失和梯度惩罚损失组成。

L
lvmengsi 已提交
210 211 212
<p align="center">
<img src="images/attgan_net.png" width=800 /> <br />
AttGAN的网络结构[8]
L
lvmengsi 已提交
213 214 215 216 217
</p>


- STGAN中生成网络再编码器和解码器之间加入Selective Transfer Units\(STU\),有选择的转换编码网络,从而更好的适配解码网络。生成网络中的编码网络主要由`convolution-instance norm-ReLU`组成,解码网络主要由`transpose convolution-norm-leaky_ReLU`组成,判别网络主要由`convolution-leaky_ReLU`组成,详细网络结构可以查看`network/STGAN_network.py`文件。生成网络的损失函数是由CGAN的损失函数,重构损失和分类损失组成,判别网络的损失函数由预测损失,分类损失和梯度惩罚损失组成。

L
lvmengsi 已提交
218 219 220
<p align="center">
<img src="images/stgan_net.png" width=800 /> <br />
STGAN的网络结构[9]
L
lvmengsi 已提交
221 222 223 224 225 226 227 228
</p>


注意:网络结构中的norm指的是用户可以选用batch norm或者instance norm来搭建自己的网络。


## FAQ

L
lvmengsi 已提交
229 230
**Q:** StarGAN/AttGAN/STGAN中属性没有变化,为什么?   
**A:** 查看是否所有的标签都转换对了。
L
lvmengsi 已提交
231

L
lvmengsi 已提交
232 233 234
**Q:** 预测结果不正常,是怎么回事?  
**A:** 某些GAN预测的时候batch_norm的设置需要和训练的时候行为一致,查看模型库中相应的GAN中预测时batch_norm的行为和自己模型中的预测时batch_norm的
行为是否一致。 
L
lvmengsi 已提交
235

L
lvmengsi 已提交
236 237 238
**Q:** 为什么STGAN和ATTGAN中变男性得到的预测结果是变女性呢?  
**A:** 这是由于预测时标签的设置,目标标签是基于原本的标签进行改变,比如原本图片是男生,预测代码对标签进行转变的时候会自动变成相对立的标签,即女
性,所以得到的结果是女生。如果想要原本是男生,转变之后还是男生,可以参考模型库中预测代码的StarGAN的标签设置。
L
lvmengsi 已提交
239 240 241


## 参考论文
L
lvmengsi 已提交
242
[1] [Goodfellow, Ian J.; Pouget-Abadie, Jean; Mirza, Mehdi; Xu, Bing; Warde-Farley, David; Ozair, Sherjil; Courville, Aaron; Bengio, Yoshua. Generative Adversarial Networks. 2014. arXiv:1406.2661 [stat.ML].](https://arxiv.org/abs/1406.2661)
L
lvmengsi 已提交
243

L
lvmengsi 已提交
244
[2] [生成对抗网络](https://zh.wikipedia.org/wiki/生成对抗网络)
L
lvmengsi 已提交
245

L
lvmengsi 已提交
246
[3] [Conditional Generative Adversarial Nets](https://arxiv.org/abs/1411.1784)
L
lvmengsi 已提交
247

L
lvmengsi 已提交
248
[4] [Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks](https://arxiv.org/abs/1511.06434)
L
lvmengsi 已提交
249

L
lvmengsi 已提交
250
[5] [Image-to-Image Translation with Conditional Adversarial Networks](https://arxiv.org/abs/1611.07004)
L
lvmengsi 已提交
251

L
lvmengsi 已提交
252
[6] [Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks](https://arxiv.org/abs/1703.10593)
L
lvmengsi 已提交
253

L
lvmengsi 已提交
254
[7] [StarGAN: Unified Generative Adversarial Networks for Multi-Domain Image-to-Image Translation](https://arxiv.org/abs/1711.09020)
L
lvmengsi 已提交
255

L
lvmengsi 已提交
256
[8] [AttGAN: Facial Attribute Editing by Only Changing What You Want](https://arxiv.org/abs/1711.10678)
L
lvmengsi 已提交
257

L
lvmengsi 已提交
258
[9] [STGAN: A Unified Selective Transfer Network for Arbitrary Image Attribute Editing](https://arxiv.org/abs/1904.09709)
L
lvmengsi 已提交
259 260 261 262 263 264 265 266

[10] [The Cityscapes Dataset for Semantic Urban Scene Understanding](https://arxiv.org/abs/1604.01685)

[11] [Deep Learning Face Attributes in the Wild](https://arxiv.org/abs/1411.7766)


## 版本更新

L
lvmengsi 已提交
267
- 6/2019 新增CGAN, DCGAN, Pix2Pix, CycleGAN,StarGAN, AttGAN, STGAN
L
lvmengsi 已提交
268 269 270 271 272 273 274

## 作者
- [ceci3](https://github.com/ceci3)
- [zhumanyu](https://github.com/zhumanyu)

## 如何贡献代码
如果你可以修复某个issue或者增加一个新功能,欢迎给我们提交PR。如果对应的PR被接受了,我们将根据贡献的质量和难度进行打分(0-5分,越高越好)。如果你累计获得了10分,可以联系我们获得面试机会或者为你写推荐信。