提交 c7932ba4 编写于 作者: Y Yang Wang 提交者: wangyang59

finished first draft of gan chapter

上级 472dec1f
# 对抗式生成网络 # 对抗式生成网络
## 背景介绍 ## 背景介绍
本章我们介绍对抗式生成网络,也称为Generative Adversarial Network(GAN)。对抗式生成网络是生成模型的一种,可以用非监督学习的办法来学习输入数据的分布,从而能达到产生和输入数据拥有同样概率分布的人造数据。 本章我们介绍对抗式生成网络,也称为Generative Adversarial Network(GAN) \[[1](#参考文献)\]。对抗式生成网络是生成模型 (generative model) 的一种,可以用非监督学习的办法来学习输入数据的分布,从而能达到产生和输入数据拥有同样概率分布的人造数据。
现在大部分利用深度学习成功的例子都是在监督学习的条件下,把高维数据映射到一种低维空间表示(representation)里来进行分类(可参见前面几章的介绍)。这种方法也叫区分模型(discriminative model)。但用这种方法学到的表示一般只是对那一种目标任务有效果,而不能很好的转移到别的任务。同时监督学习的训练需要大量标记好的数据,很多时候不是很容易得到。
所以为了能够从大量无标记数据里学到通用有效的表示,人们发明了另一种模型叫作生成模型。这个方法背后的基本想法是,如果一个模型它能够生成和真实数据非常相近的数据,那么很可能它就学到了对于这种数据的一种很有效的表示。生成模型另一些实际用途包括,图像去噪,缺失图像补全,图像超分辨生成等等。在标记数据不够的时候,还可以用生成模型生成的数据来预训练模型。
现在常用的生成模型大致有两种类型,一种是变分自编码器(variational autoencoder)\[[3](#参考文献)\],它是在概率图模型(probabilistic graphical model)的框架下面搭建了一个生成模型,对数据有完整的概率描述,训练时是通过调节参数来最大化数据的概率。用这种方法产生的图片,虽然所对应的概率高,但很多时候看起来都比较模糊。为了解决这个问题,人们又提出了本章所要介绍的另一种生成模型,对抗式生成网络。
在本章里,我们展对抗式生产网络的细节,以及如何用PaddlePaddle训练一个GAN模型。 在本章里,我们展对抗式生产网络的细节,以及如何用PaddlePaddle训练一个GAN模型。
...@@ -14,17 +20,28 @@ ...@@ -14,17 +20,28 @@
</p> </p>
## 模型概览 ## 模型概览
对抗式生成网络的大致结构在图2中画出,它由两部分组成:一个生成器(G)和一个分别器(D),两者都是有多层神经网络构成的。生成器的输入是一个多维的已知概率分布的噪音(z),通过神经网络变换,输出伪样本。分别器输的输入是真样本和伪样本,输出为判断样本为真样本的概率。训练时生成器和分别器处于相互竞争对抗状态,生成器会尽量生成和真样本相近的伪样本让分别器无法分辨真伪,而分别器则会尽力去分辨伪样本。具体的损失函数如下: 对抗式生成网络的大致结构在图2中画出,它由两部分组成:一个生成器(Generator)G 和一个分类器(Discriminator, 也称判别器)D,两者都是有多层神经网络构成的。生成器的输入是一个多维的已知概率分布的噪音 z,通过神经网络变换,输出伪样本。分类器输的输入是真样本和伪样本,输出为分类结果为真样本和伪样本的概率。训练时生成器和分类器处于相互竞争对抗状态,生成器会尽量生成和真样本相近的伪样本让分类器无法分辨真伪,而分类器则会尽力去分辨伪样本。具体的损失函数如下:
$$\min_G\max_D \text{Loss} = \min_G\max_D \frac{1}{m}\sum_{i=1}^m[\log D(x^i) + log(1-D(G(z^i)))]$$ $$\min_G\max_D \text{Loss} = \min_G\max_D \frac{1}{m}\sum_{i=1}^m[\log D(x^i) + log(1-D(G(z^i)))]$$
这个损失函数里面$x$是真实数据,$z$是已知概率分布的噪音。所以这个损失函数所代表的意义就是真实数据被分类为真的概率加上伪数据被分类为假的概率。分类器 D 目标是增加这个函数值,故公式里为max,而生成器 G 目标是减少这个函数值,故公式里为min。
<p align="center"> <p align="center">
<img src="./gan.png" width="500" height="300"><br/> <img src="./gan.png" width="500" height="300"><br/>
图2. GAN模型结构 图2. GAN模型结构
<a href="https://ishmaelbelghazi.github.io/ALI/">figure credit</a> <a href="https://ishmaelbelghazi.github.io/ALI/">figure credit</a>
</p> </p>
训练时,生成器和分别器会轮流通过随机梯度下降算法更新参数。生成器的目标函数是让自己产生的样本被分别器分类为真,而分别器的目标函数则是正确的区分真伪样本。当对抗式生成模型训练收敛到平衡态的时候,生成器会把输入的噪音分布转化成真的样本数据分布,而分别器则完全无法分辨真伪图片。 训练时,生成器和分类器会轮流通过随机梯度下降算法更新参数。生成器的目标函数是让自己产生的样本被分类器分类为真,而分类器的目标函数则是正确的区分真伪样本。当对抗式生成模型训练收敛到平衡态的时候,生成器会把输入的噪音分布转化成真的样本数据分布,而分类器则完全无法分辨真伪图片。
在最早的对抗式生成网络的论文中,生成器和分类器用的都是全联接层,所以没有办法很好的生成图片数据,也没有办法做的很深。所以在随后的论文中,人们提出了深度卷积对抗式生成网络(deep convolutional generative adversarial network or DCGAN)\[[2](#参考文献)\]。在DCGAN中,生成器 G 是由多个卷积转置层(transposed convolution)组成的,这样可以用更少的参数来生成质量更高的图片。具体网络结果可参见图3。
<p align="center">
<img src="./dcgan.png" width="700" height="300"><br/>
图3. DCGAN生成器模型结构
<a href="https://arxiv.org/pdf/1511.06434v2.pdf/">figure credit</a>
</p>
## 数据准备 ## 数据准备
...@@ -51,7 +68,7 @@ $./get_mnist_data.sh ...@@ -51,7 +68,7 @@ $./get_mnist_data.sh
由于对抗式生产网络涉及到多个神经网络,所以必须用paddle Python API来训练。下面的介绍也可以部分的拿来当作paddle Python API的使用说明。 由于对抗式生产网络涉及到多个神经网络,所以必须用paddle Python API来训练。下面的介绍也可以部分的拿来当作paddle Python API的使用说明。
### 模型结构 ### 模型结构
在文件gan_conf.py当中我们定义了三个网络, **generator_training**, **discriminator_training** and **generator**. 和前文提到的模型结构的关系是:**discriminator_training** 是分别器,**generator** 是生成器,**generator_training** 是生成器加分别器因为训练生成器时需要用到分别器提供目标函数。这个对应关系在下面这段代码中定义: 在文件gan_conf.py当中我们定义了三个网络, **generator_training**, **discriminator_training** and **generator**. 和前文提到的模型结构的关系是:**discriminator_training** 是分类器,**generator** 是生成器,**generator_training** 是生成器加分类器因为训练生成器时需要用到分类器提供目标函数。这个对应关系在下面这段代码中定义:
```python ```python
if is_generator_training: if is_generator_training:
...@@ -101,7 +118,7 @@ dis_trainer = api.Trainer.create(dis_conf, dis_training_machine) ...@@ -101,7 +118,7 @@ dis_trainer = api.Trainer.create(dis_conf, dis_training_machine)
gen_trainer = api.Trainer.create(gen_conf, gen_training_machine) gen_trainer = api.Trainer.create(gen_conf, gen_training_machine)
``` ```
为了能够平衡生成器和分器之间的能力,我们依据它们各自的损失函数的大小来决定训练对象,即我们选择训练那个损失函数更大的网络。损失函数的值可以通过GradientMachine的forward pass来计算。 为了能够平衡生成器和分器之间的能力,我们依据它们各自的损失函数的大小来决定训练对象,即我们选择训练那个损失函数更大的网络。损失函数的值可以通过GradientMachine的forward pass来计算。
```python ```python
def get_training_loss(training_machine, inputs): def get_training_loss(training_machine, inputs):
...@@ -124,10 +141,20 @@ copy_shared_parameters(gen_training_machine, generator_machine) ...@@ -124,10 +141,20 @@ copy_shared_parameters(gen_training_machine, generator_machine)
``` ```
### 数据定义 ### 数据定义
这里数据没有通过dataprovider提供,而是在gan_trainer.py里面直接产生data_batch并提供给trainer。 这里数据没有通过dataprovider提供,而是在gan_trainer.py里面直接产生data_batch并以Arguments的形式提供给trainer。
```python ```python
code to be inserted def prepare_generator_data_batch(batch_size, noise):
label = numpy.ones(batch_size, dtype='int32')
inputs = api.Arguments.createArguments(2)
inputs.setSlotValue(0, api.Matrix.createDenseFromNumpy(noise))
inputs.setSlotIds(1, api.IVector.createVectorFromNumpy(label))
return inputs
Create data_batch for generator
data_batch_gen = prepare_generator_data_batch(batch_size, noise)
# Feed data_batch_gen into generator trainer
gen_trainer.trainOneDataBatch(batch_size, data_batch_gen)
``` ```
### 算法配置 ### 算法配置
...@@ -149,13 +176,27 @@ settings( ...@@ -149,13 +176,27 @@ settings(
$python gan_trainer.py -d mnist --useGpu 1 $python gan_trainer.py -d mnist --useGpu 1
``` ```
## 应用模型
图片由训练好的生成器生成。以下的代码将噪音z输入到生成器 G 当中,通过向前传递得到生成的图片。
```python
def get_fake_samples(generator_machine, batch_size, noise):
gen_inputs = api.Arguments.createArguments(1)
gen_inputs.setSlotValue(0, api.Matrix.createDenseFromNumpy(noise))
gen_outputs = api.Arguments.createArguments(0)
generator_machine.forward(gen_inputs, gen_outputs, api.PASS_TEST)
fake_samples = gen_outputs.getSlotValue(0).copyToNumpyMat()
return fake_samples
# At the end of each pass, save the generated samples/images
fake_samples = get_fake_samples(generator_machine, batch_size, noise)
```
## 总结 ## 总结
本章中, 本章中,我们介绍了对抗式生成网络的基本概念,训练方法以及如何用Paddle来实现。对抗式生成网络是现有生成模型当中非常重要的一种,它可以利用大量无标记数据来进行非监督学习,以寄希望能够得到对于复杂高维数据的一般有效的表示。
## 参考文献 ## 参考文献
1. Bengio Y, Ducharme R, Vincent P, et al. [A neural probabilistic language model](http://www.jmlr.org/papers/volume3/bengio03a/bengio03a.pdf)[J]. journal of machine learning research, 2003, 3(Feb): 1137-1155. 1. Goodfellow I, Pouget-Abadie J, Mirza M, et al. [Generative adversarial nets](https://arxiv.org/pdf/1406.2661v1.pdf)[C] Advances in Neural Information Processing Systems. 2014
2. Mikolov T, Sutskever I, Chen K, et al. [Distributed representations of words and phrases and their compositionality](http://papers.nips.cc/paper/5021-distributed-representations-of-words-and-phrases-and-their-compositionality.pdf)[C]//Advances in neural information processing systems. 2013: 3111-3119. 2. Radford A, Metz L, Chintala S. [Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks](https://arxiv.org/pdf/1511.06434v2.pdf)[C] arXiv preprint arXiv:1511.06434. 2015
3. Mikolov T, Kombrink S, Deoras A, et al. [Rnnlm-recurrent neural network language modeling toolkit](http://www.fit.vutbr.cz/~imikolov/rnnlm/rnnlm-demo.pdf)[C]//Proc. of the 2011 ASRU Workshop. 2011: 196-201. 3. Kingma D.P. and Welling M. [Auto-encoding variational bayes](https://arxiv.org/pdf/1312.6114v10.pdf)[C] arXiv preprint arXiv:1312.6114. 2013
4. Mikolov T, Chen K, Corrado G, et al. [Efficient estimation of word representations in vector space\[J\]](https://arxiv.org/pdf/1301.3781.pdf). arXiv preprint arXiv:1301.3781, 2013. \ No newline at end of file
<!-- 5. Mikolov T, Karafiát M, Burget L, et al. [Recurrent neural network based language model](http://www.fit.vutbr.cz/research/groups/speech/publi/2010/mikolov_interspeech2010_IS100722.pdf)[C]//Interspeech. 2010, 2: 3. -->
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册