提交 fce00da7 编写于 作者: W wangyang59 提交者: GitHub

Merge pull request #2 from wangkuiyi/yi-gan

Yi's change to GAN tutorial
# 对抗式生成网络 # 对抗式生成网络
## 背景介绍 ## 背景介绍
本章我们介绍对抗式生成网络,也称为Generative Adversarial Network(GAN) \[[1](#参考文献)\]。对抗式生成网络是生成模型 (generative model) 的一种,可以用非监督学习的办法来学习输入数据的分布,从而能达到产生和输入数据拥有同样概率分布的人造数据。这样的学习能力可以帮助机器完成图片自动生成、图像去噪、缺失图像补全和图像超分辨生成等工作。
深度学习现有的方法大致可以分为两大类,判别模型(discriminative model)和生成模型(generative model) 本章我们介绍对抗式生成网络,也称为Generative Adversarial Network (GAN) \[[1](#参考文献)\]。GAN的核心思想是,为了更好地训练一个生成式神经元网络模型(generative model),我们引入一个判别式神经元网络模型来构造优化目标函数。实验证明,在图像自动生成、图像去噪、和缺失图像补全等应用里,这种方法可以训练处一个能更逼近训练数据分布的生成式模型
判别模型是在监督学习的条件下,把高维数据映射到一种低维空间表示(representation)里来进行分类(可参见前面几章的介绍),它直接对条件概率P(y|x)建模。像我们的前八章,都是判别模型。但用这种方法学到的表示一般只是对那一种目标任务有效果,而不能很好的转移到别的任务。同时监督学习的训练需要大量标记好的数据,很多时候不是很容易得到 到目前为止,大部分取得好的应用效果的神经元网络模型都是有监督训练(supervised learning)的判别式模型(discriminative models),包括图像识别中使用的convolutional networks和在语音识别中使用的connectionist temporal classification (CTC) networks。在这些例子里,训练数据 X 都是带有标签 y 的——每张图片附带了一个或者多个tag,每段语音附带了一段对应的文本;而模型的输入是 X,输出是 y,训练得到的模型表示从X到y的映射函数 y=f(X)
生成模型在监督学习和非监督学习的条件下都可以应用。在监督学习的条件下,生成模型是直接对联合概率P(X,Y)建模。在非监督学习的条件下,生成模型是对P(X)进行建模。生成模型背后的基本想法是,如果一个模型它能够生成和真实数据非常相近的数据,那么很可能它就学到了对于这种数据的一种很有效的表示。生成模型另一些实际用途包括,图像去噪,缺失图像补全,图像超分辨生成等等。在标记数据不够的时候,还可以用生成模型生成的数据来预训练模型 和判别式神经元网络模型相对的一类模型是生成式模型(generative models)。它们通常是通过非监督训练(unsupervised learning)来得到的。这类模型的训练数据里只有 X,没有y。训练的目标是希望模型能蕴含训练数据的统计分布信息,从而可以从训练好的模型里产生出新的、在训练数据里没有出现过的新数据 x'
生成模型一个重要的研究方向是图片生成。相比于生成文字,由于图片数据的维度更大并且数值是连续的,所以生成起来难度更大。关于图片生成的研究已经有比较久的历史,之前的方法有,受限玻尔兹曼机(Restricted Boltzmann Machine)\[[4](#参考文献)\],深度玻尔兹曼机(Deep Boltzmann Machine)\[[5](#参考文献)\],神经自回归分布估计(Neural Autoregressive Distribution Estimator)\[[6](#参考文献)\]等。但它们都无法生成看起来很真实的图片 一些为人熟知的生成模型的例子包括受限玻尔兹曼机(Restricted Boltzmann Machine)\[[4](#参考文献)\],深度玻尔兹曼机(Deep Boltzmann Machine)\[[5](#参考文献)\],神经自回归分布估计(Neural Autoregressive Distribution Estimator)\[[6](#参考文献)\]
近年来由于深度学习的发展,出现了一些更有效的图片生成模型,一种是变分自编码器(variational autoencoder)\[[3](#参考文献)\],它是在概率图模型(probabilistic graphical model)的框架下面搭建了一个生成模型,对数据有完整的概率描述,训练时是通过调节参数来最大化数据的概率。用这种方法产生的图片,虽然所对应的概率高,但很多时候看起来都比较模糊。另一种是像素循环神经网络(Pixel Recurrent Neural Network)\[[7](#参考文献)\],它是通过根据周围的像素来一个像素一个像素的生成图片,但这种方法生成的图片在全局看来会不太一致。为了解决这些问题,人们又提出了本章所要介绍的另一种生成模型,对抗式生成网络。 近年出现了一些专门用来生成图像的模型,一种是变分自编码器(variational autoencoder)\[[3](#参考文献)\],它是在概率图模型(probabilistic graphical model)的框架下面搭建了一个生成模型,对数据有完整的概率描述,训练时是通过调节参数来最大化数据的概率。用这种方法产生的图片,虽然似然(likelihood)比较高,但经常看起来比较模糊。另一种是像素循环神经网络(Pixel Recurrent Neural Network)\[[7](#参考文献)\],它是通过根据周围的像素来一个像素一个像素的生成图片,但这种方法生成的图片在全局看来会不太一致。为了解决这些问题,人们又提出了本章所要介绍的另一种生成模型,对抗式生成网络。
在本章里,我们展对抗式生产网络的细节,以及如何用PaddlePaddle训练一个GAN模型。 本文介绍如何训练一个产生式神经元网络模型,它的输入是一个随机生成的向量(相当于不需要任何有意义的输入),而输出是一幅图像,其中有一个数字。换句话说,我们训练一个会写字(阿拉伯数字)的神经元网络模型。它“写”的一些数字如下图:
## 效果展示
一个简单的例子是训练对抗式生成网络,使其学习产生MNIST手写数字的图片。由训练好的GAN模型产生的手写数字图片的例子画在图1中。
<p align="center"> <p align="center">
<img src="./mnist_sample.png" width="300" height="300"><br/> <img src="./mnist_sample.png" width="300" height="300"><br/>
图1. GAN生成的MNIST例图 图1. GAN生成的MNIST例图
</p> </p>
## 模型概览 现实中成功使用的生成式神经元网络模型往往接受有意义的输入。比如可能接受一幅低分辨率的图像,输出对应的高分辨率图像。这样的模型被称为 conditional GAN这过程实际上是从大量数据学习得到模型,或者说归纳得到知识,然后用这些知识来补足图像的分辨率。
对抗式生成网络的原理示意图在图2中画出,它由两部分组成:一个生成器(Generator)G 和一个分类器(Discriminator, 也称判别器)D,两者都是有多层神经网络构成的。生成器的输入是一个多维的已知概率分布的噪音 z(噪音的概率分布不取决于待生成样本,如可以服从正态分布),通过神经网络变换,输出伪样本。分类器输的输入是真样本和伪样本,输出为分类结果为真样本和伪样本的概率。训练时生成器和分类器处于相互竞争对抗状态,生成器会尽量生成和真样本相近的伪样本让分类器无法分辨真伪,而分类器则会尽力去分辨伪样本。具体的损失函数如下:
$$\min_G\max_D \text{Loss} = \min_G\max_D \frac{1}{m}\sum_{i=1}^m[\log D(x^i) + log(1-D(G(z^i)))]$$ ## 传统训练方式和对抗式训练
其中$x$是真实数据,$z$是已知概率分布的噪音。所以这个损失函数所代表的意义就是真实数据被分类为真的概率加上伪数据被分类为假的概率。分类器 D 目标是增加这个函数值,故公式里为max,而生成器 G 目标是减少这个函数值,故公式里为min 因为神经元网络是一个有向图,总是有输入和输出的。当我们用无监督学习方式来训练一个神经元网络,用于描述训练数据分布的时候,一个通常的学习目标是估计一组参数,使得输出和输入很接近 —— 或者说输入是什么输出就是什么。很多早期的生成式神经元网络模型,包括受限波尔茨曼机和 autoencoder 都是这么训练的。这种情况下优化目标经常是最小化输出和输入的差别
<p align="center"> <p align="center">
<img src="./gan.png" width="500" height="300"><br/> <img src="./gan.png" width="500" height="300"><br/>
...@@ -36,9 +32,25 @@ $$\min_G\max_D \text{Loss} = \min_G\max_D \frac{1}{m}\sum_{i=1}^m[\log D(x^i) + ...@@ -36,9 +32,25 @@ $$\min_G\max_D \text{Loss} = \min_G\max_D \frac{1}{m}\sum_{i=1}^m[\log D(x^i) +
<a href="https://ishmaelbelghazi.github.io/ALI/">figure credit</a> <a href="https://ishmaelbelghazi.github.io/ALI/">figure credit</a>
</p> </p>
训练时,生成器和分类器会轮流通过随机梯度下降算法更新参数。生成器的目标函数是让自己产生的样本被分类器分类为真,而分类器的目标函数则是正确的区分真伪样本。当对抗式生成模型训练收敛到平衡态的时候,生成器会把输入的噪音分布转化成真的样本数据分布,而分类器则完全无法分辨真伪图片。 对抗式训练里,我们用一个判别式模型 D 辅助构造优化目标函数,来训练一个生成式模型 G。如图2所示。具体训练流程是不断交替执行如下两步:
1. 更新模型 D:
1. 固定G的参数不变,对于一组随机输入,得到一组(产生式)输出,$X_f$,并且将其label成“假”。
1. 从训练数据 X 采样一组 $X_r$,并且label为“真”。
1. 用这两组数据更新模型 D,从而使D能够分辨G产生的数据和真实训练数据。
1. 更新模型 G:
1. 把G的输出和D的输入连接起来,得到一个网路。
1. 给G一组随机输入,期待G的输出让D认为像是“真”的。
1. 在D的输出端,优化目标是通过更新G的参数来最小化D的输出和“真”的差别。
上述方法实际上在优化如下目标:
$$\min_G \max_D \frac{1}{N}\sum_{i=1}^N[\log D(x^i) + \log(1-D(G(z^i)))]$$
在最早的对抗式生成网络的论文中,生成器和分类器用的都是全联接层。在附带的代码gan_conf.py中,我们实现了一个类似的结构。生成器和分类器都是由三层全联接层构成,并且在某些全联接层后面加入了批标准化层(batch normalization)。所用网络结构在图3中给出。生成器的损失函数是其所生成的伪样本$x'$被判别器判定为真的概率,而判别器的损失函数是伪样本$x'$被判定为假的概率加上真样本$x$被判别为真的概率。 其中$x$是真实数据,$z$是随机产生的输入,$N$是训练数据的数量。这个损失函数的意思是:真实数据被分类为真的概率加上伪数据被分类为假的概率。因为上述两步交替优化G生成的结果的仿真程度(看起来像x)和D分辨G的生成结果和x的能力,所以这个方法被称为对抗(adversarial)方法。
在最早的对抗式生成网络的论文中,生成器和分类器用的都是全联接层。在附带的代码[`gan_conf.py`](./gan_conf.py)中,我们实现了一个类似的结构。G和D是由三层全联接层构成,并且在某些全联接层后面加入了批标准化层(batch normalization)。所用网络结构在图3中给出。
<p align="center"> <p align="center">
<img src="./gan_conf_graph.png" width="700" height="400"><br/> <img src="./gan_conf_graph.png" width="700" height="400"><br/>
...@@ -54,12 +66,9 @@ $$\min_G\max_D \text{Loss} = \min_G\max_D \frac{1}{m}\sum_{i=1}^m[\log D(x^i) + ...@@ -54,12 +66,9 @@ $$\min_G\max_D \text{Loss} = \min_G\max_D \frac{1}{m}\sum_{i=1}^m[\log D(x^i) +
</p> </p>
## 数据准备 ## 数据
### 数据介绍与下载
这章会用到两种数据,一种是简单的人造数据,一种是图片。
人造数据是二维0到1之间的均匀分布,由下面的代码生成(numpy.random.rand会生成0-1均匀分布随机数): 这章会用到两种数据,一种是G的随机输入,另一种是来自MNIST数据集的图片,其中一张是人类手写的一个数字。随机输入数据的生成方式如下:
```python ```python
# 合成2-D均匀分布数据 gan_trainer.py:114 # 合成2-D均匀分布数据 gan_trainer.py:114
...@@ -68,19 +77,27 @@ def load_uniform_data(): ...@@ -68,19 +77,27 @@ def load_uniform_data():
return data return data
``` ```
图片数据是MNIST手写数字和CIFAR-10,可由下面的代码下载: MNIST数据可以通过执行[get_mnist_data.sh](./data/get_mnist_data.sh)下载:
```bash ```bash
$cd data/ $cd data/
$./get_mnist_data.sh $./get_mnist_data.sh
```
其实只需要换一种图像数据集,这个例子即可训练G来生成对应的类似图像。比如Cifar-10数据集可由执行[download_cifa.sh](./data/download_cifa.sh)下载:
```bash
$cd data/
$./download_cifar.sh $./download_cifar.sh
``` ```
## 模型配置说明 ## 模型配置
由于对抗式生产网络涉及到多个神经网络,所以必须用paddle Python API来训练。下面的介绍也可以部分的拿来当作paddle Python API的使用说明。
由于对抗式生产网络涉及到多个神经网络,所以必须用PaddlePaddle Python API来训练。下面的介绍也可以部分的拿来当作PaddlePaddle Python API的使用说明。
### 数据定义 ### 数据定义
这里数据没有通过dataprovider提供,而是在gan_trainer.py里面直接产生data_batch并以Arguments的形式提供给trainer。
这里数据没有通过data provider提供,而是在`gan_trainer.py`里面直接产生`data_batch`并以`Arguments`的形式提供给trainer。
```python ```python
def prepare_generator_data_batch(batch_size, noise): def prepare_generator_data_batch(batch_size, noise):
...@@ -103,7 +120,7 @@ gen_trainer.trainOneDataBatch(batch_size, data_batch_gen) ...@@ -103,7 +120,7 @@ gen_trainer.trainOneDataBatch(batch_size, data_batch_gen)
### 算法配置 ### 算法配置
在这里,我们指定了模型的训练参数, 选择学习率和batch size。这里的beta1参数比默认值0.9小很多是为了使学习的过程更稳定。 在这里,我们指定了模型的训练参数, 选择learning rate和batch size。这里的`beta1`参数比默认值0.9小很多是为了使学习的过程更稳定。
```python ```python
settings( settings(
...@@ -114,7 +131,8 @@ settings( ...@@ -114,7 +131,8 @@ settings(
``` ```
### 模型结构 ### 模型结构
在文件gan_conf.py当中我们定义了三个网络, **generator_training**, **discriminator_training** and **generator**. 和前文提到的模型结构的关系是:**discriminator_training** 是分类器,**generator** 是生成器,**generator_training** 是生成器加分类器因为训练生成器时需要用到分类器提供目标函数。这个对应关系在下面这段代码中定义:
在文件`gan_conf.py`当中我们定义了三个网络, **generator_training**, **discriminator_training** and **generator**. 和前文提到的模型结构的关系是:**discriminator_training** 是分类器,**generator** 是生成器,**generator_training** 是生成器加分类器因为训练生成器时需要用到分类器提供目标函数。这个对应关系在下面这段代码中定义:
```python ```python
if is_generator_training: if is_generator_training:
...@@ -139,14 +157,16 @@ if is_generator: ...@@ -139,14 +157,16 @@ if is_generator:
outputs(generator(noise)) outputs(generator(noise))
``` ```
##训练模型 ## 训练模型
用MNIST手写数字图片训练对抗式生成网络可以用如下的命令: 用MNIST手写数字图片训练对抗式生成网络可以用如下的命令:
```bash ```bash
$python gan_trainer.py -d mnist --use_gpu 1 $python gan_trainer.py -d mnist --use_gpu 1
``` ```
训练中打印的日志信息如下: 训练中打印的日志信息大致如下:
``` ```
d_pos_loss is 0.681067 d_neg_loss is 0.704936 d_pos_loss is 0.681067 d_neg_loss is 0.704936
d_loss is 0.693001151085 g_loss is 0.681496 d_loss is 0.693001151085 g_loss is 0.681496
...@@ -160,9 +180,16 @@ d_loss is 0.680106401443 g_loss is 0.671118 ...@@ -160,9 +180,16 @@ d_loss is 0.680106401443 g_loss is 0.671118
I0105 17:16:37.172737 20517 TrainerInternal.cpp:165] Batch=100 samples=12800 AvgCost=0.687359 CurrentCost=0.687359 Eval: discriminator_training_error=0.438359 CurrentEval: discriminator_training_error=0.438359 I0105 17:16:37.172737 20517 TrainerInternal.cpp:165] Batch=100 samples=12800 AvgCost=0.687359 CurrentCost=0.687359 Eval: discriminator_training_error=0.438359 CurrentEval: discriminator_training_error=0.438359
``` ```
其中d_pos_loss是判别器对于真实数据判别真的负对数概率,d_neg_loss是判别器对于伪数据判别为假的负对数概率,d_loss是这两者的平均值。g_loss是伪数据被判别器判别为真的负对数概率。对于对抗式生成网络来说,最好的训练情况是D和G的能力比较相近,也就是d_loss和g_loss在训练的前几个pass中数值比较接近(-log(0.5) = 0.693)。由于G和D是轮流训练,所以它们各自每过100个batch,都会打印各自的训练信息。 其中`d_pos_loss`是判别器对于真实数据判别真的负对数概率,`d_neg_loss`是判别器对于伪数据判别为假的负对数概率,`d_loss`是这两者的平均值。`g_loss`是伪数据被判别器判别为真的负对数概率。对于对抗式生成网络来说,最好的训练情况是D和G的能力比较相近,也就是`d_loss``g_loss`在训练的前几个pass中数值比较接近(-log(0.5) = 0.693)。由于G和D是轮流训练,所以它们各自每过100个batch,都会打印各自的训练信息。
为了能够训练在gan_conf.py中定义的网络,我们需要如下几个步骤:
1. 初始化Paddle环境,
1. 解析设置,
1. 由设置创造GradientMachine以及由GradientMachine创造trainer。
这几步分别由下面几段代码实现:
为了能够训练在gan_conf.py中定义的网络,我们需要如下几个步骤:初始化Paddle环境,解析设置,由设置创造GradientMachine以及由GradientMachine创造trainer。这几步分别由下面几段代码实现:
```python ```python
import py_paddle.swig_paddle as api import py_paddle.swig_paddle as api
...@@ -189,7 +216,7 @@ dis_trainer = api.Trainer.create(dis_conf, dis_training_machine) ...@@ -189,7 +216,7 @@ dis_trainer = api.Trainer.create(dis_conf, dis_training_machine)
gen_trainer = api.Trainer.create(gen_conf, gen_training_machine) gen_trainer = api.Trainer.create(gen_conf, gen_training_machine)
``` ```
为了能够平衡生成器和分类器之间的能力,我们依据它们各自的损失函数的大小来决定训练对象,即我们选择训练那个损失函数更大的网络。损失函数的值可以通过GradientMachine的forward pass来计算。 为了能够平衡生成器和分类器之间的能力,我们依据它们各自的损失函数的大小来决定训练对象,即我们选择训练那个损失函数更大的网络。损失函数的值可以通过调用`GradientMachine``forward`方法来计算。
```python ```python
def get_training_loss(training_machine, inputs): def get_training_loss(training_machine, inputs):
...@@ -212,7 +239,8 @@ copy_shared_parameters(gen_training_machine, generator_machine) ...@@ -212,7 +239,8 @@ copy_shared_parameters(gen_training_machine, generator_machine)
``` ```
## 应用模型 ## 应用模型
图片由训练好的生成器生成。以下的代码将噪音z输入到生成器 G 当中,通过向前传递得到生成的图片。
图片由训练好的生成器生成。以下的代码将随机向量输入到模型 G,通过向前传递得到生成的图片。
```python ```python
# 噪音z是多维正态分布 # 噪音z是多维正态分布
...@@ -233,10 +261,12 @@ fake_samples = get_fake_samples(generator_machine, batch_size, noise) ...@@ -233,10 +261,12 @@ fake_samples = get_fake_samples(generator_machine, batch_size, noise)
``` ```
## 总结 ## 总结
本章中,我们介绍了对抗式生成网络的基本概念,训练方法以及如何用Paddle来实现。对抗式生成网络是现有生成模型当中非常重要的一种,它可以利用大量无标记数据来进行非监督学习,以寄希望能够得到对于复杂高维数据的一般有效的表示。
本章中,我们介绍了对抗式生成网络的基本概念,训练方法以及如何用PaddlePaddle来训练一个简单的图像生成模型。对抗式生成网络是一种新的训练生成模型的有效方法,我们期待看到它的更有意思的应用场景。
## 参考文献 ## 参考文献
1. Goodfellow I, Pouget-Abadie J, Mirza M, et al. [Generative adversarial nets](https://arxiv.org/pdf/1406.2661v1.pdf)[C] Advances in Neural Information Processing Systems. 2014 1. Goodfellow I, Pouget-Abadie J, Mirza M, et al. [Generative adversarial nets](https://arxiv.org/pdf/1406.2661v1.pdf)[C] Advances in Neural Information Processing Systems. 2014
2. Radford A, Metz L, Chintala S. [Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks](https://arxiv.org/pdf/1511.06434v2.pdf)[C] arXiv preprint arXiv:1511.06434. 2015 2. Radford A, Metz L, Chintala S. [Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks](https://arxiv.org/pdf/1511.06434v2.pdf)[C] arXiv preprint arXiv:1511.06434. 2015
3. Kingma D.P. and Welling M. [Auto-encoding variational bayes](https://arxiv.org/pdf/1312.6114v10.pdf)[C] arXiv preprint arXiv:1312.6114. 2013 3. Kingma D.P. and Welling M. [Auto-encoding variational bayes](https://arxiv.org/pdf/1312.6114v10.pdf)[C] arXiv preprint arXiv:1312.6114. 2013
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册