提交 8ffcfc80 编写于 作者: Y Yi Wang

Update README.md

上级 c00f0bf7
...@@ -2,20 +2,20 @@ ...@@ -2,20 +2,20 @@
## 背景介绍 ## 背景介绍
本章我们介绍对抗式生成网络,也称为Generative Adversarial Network (GAN) \[[1](#参考文献)\]。GAN的核心思想是,为了更好地训练一个生成式神经元网络模型(generative model),我们引入一个分类神经元网络模型来构造优化目标函数。实验证明,在图像自动生成、图像去噪、和确实图像不全等应用里,这种方法可以更容易地得到一个能更好逼近训练数据分布的生成式模型。 本章我们介绍对抗式生成网络,也称为Generative Adversarial Network (GAN) \[[1](#参考文献)\]。GAN的核心思想是,为了更好地训练一个生成式神经元网络模型(generative model),我们引入一个判别式神经元网络模型来构造优化目标函数。实验证明,在图像自动生成、图像去噪、和缺失图像补全等应用里,这种方法可以训练处一个能更逼近训练数据分布的生成式模型。
到目前为止,大部分在应用中取得好效果的神经元网络模型都是有监督训练(supervised learning)的判别式模型(discriminative models),包括图像识别中使用的convolutional networks和在语音识别中使用的connectionist temporal classification (CTC) networks。在这些例子里,训练数据 X 都是带有标签 y 的——每张图片附带了一个或者多个tag,每段语音附带了一段对应的文本;而模型的输入是 X,输出是 y,表示从X到y的映射函数 y=f(X)。 到目前为止,大部分取得好的应用效果的神经元网络模型都是有监督训练(supervised learning)的判别式模型(discriminative models),包括图像识别中使用的convolutional networks和在语音识别中使用的connectionist temporal classification (CTC) networks。在这些例子里,训练数据 X 都是带有标签 y 的——每张图片附带了一个或者多个tag,每段语音附带了一段对应的文本;而模型的输入是 X,输出是 y,训练得到的模型表示从X到y的映射函数 y=f(X)。
和判别式神经元网络模型相对的一类模型是生成式模型(generative models)。它们通常是通过非监督训练(unsupervised learning)来得到的。这类模型的训练数据里只有 X,没有y。训练的目标是希望模型能包含训练数据的分布,从而可以从训练好的模型里产生出新的、在训练数据里没有出现过的数据 x'。 和判别式神经元网络模型相对的一类模型是生成式模型(generative models)。它们通常是通过非监督训练(unsupervised learning)来得到的。这类模型的训练数据里只有 X,没有y。训练的目标是希望模型能蕴含训练数据的统计分布信息,从而可以从训练好的模型里产生出新的、在训练数据里没有出现过的新数据 x'。
本文里,我们介绍如何训练一个产生式神经元网络模型,它的输入是一个随机生成的向量(相当于不需要任何有意义的输入),而输出是一幅图像,其中有一个数字。换句话说,我们训练一个会写字(阿拉伯数字)的神经元网络模型。它“写”的一些数字如下图: 本文介绍如何训练一个产生式神经元网络模型,它的输入是一个随机生成的向量(相当于不需要任何有意义的输入),而输出是一幅图像,其中有一个数字。换句话说,我们训练一个会写字(阿拉伯数字)的神经元网络模型。它“写”的一些数字如下图:
<p align="center"> <p align="center">
<img src="./mnist_sample.png" width="300" height="300"><br/> <img src="./mnist_sample.png" width="300" height="300"><br/>
图1. GAN生成的MNIST例图 图1. GAN生成的MNIST例图
</p> </p>
现实中成功使用的生成式神经元网络模型往往接受有意义的输入。比如可能接受一幅低分辨率的图像,输出对应的高分辨率图像。这过程实际上是从大量数据学习得到模型,或者说归纳得到知识,然后用这些知识来补足图像的分辨率。 现实中成功使用的生成式神经元网络模型往往接受有意义的输入。比如可能接受一幅低分辨率的图像,输出对应的高分辨率图像。这样的模型被称为 conditional GAN这过程实际上是从大量数据学习得到模型,或者说归纳得到知识,然后用这些知识来补足图像的分辨率。
## 传统训练方式和对抗式训练 ## 传统训练方式和对抗式训练
...@@ -50,12 +50,9 @@ $$\min_G \max_D \frac{1}{N}\sum_{i=1}^N[\log D(x^i) + \log(1-D(G(z^i)))]$$ ...@@ -50,12 +50,9 @@ $$\min_G \max_D \frac{1}{N}\sum_{i=1}^N[\log D(x^i) + \log(1-D(G(z^i)))]$$
## 数据准备 ## 数据准备
todo(yi): from here on
### 数据介绍与下载 ### 数据介绍与下载
这章会用到两种数据,一种是简单的人造数据,一种是图片。
人造数据是二维均匀分布,由下面的代码生成: 这章会用到两种数据,一种是G的随机输入,一种是来自MNIST数据集的图片,其中一张是人类手写的一个数字。随机输入数据的生成方式如下:
```python ```python
# synthesize 2-D uniform data in gan_trainer.py:114 # synthesize 2-D uniform data in gan_trainer.py:114
...@@ -64,14 +61,14 @@ def load_uniform_data(): ...@@ -64,14 +61,14 @@ def load_uniform_data():
return data return data
``` ```
图片数据是MNIST手写数字,可由下面的代码下载: MNIST数据可以通过执行[get_mnist_data.sh](./data/get_mnist_data.sh)下载:
```bash ```bash
$cd data/ $cd data/
$./get_mnist_data.sh $./get_mnist_data.sh
``` ```
另一种更真实的图片数据是Cifar-10,可由下面的代码下载: 其实只需要换一种图像数据集,这个例子即可训练G来生成对应的类似图像。比如Cifar-10数据集可由执行[download_cifa.sh](./data/download_cifa.sh)下载:
```bash ```bash
$cd data/ $cd data/
...@@ -79,10 +76,12 @@ $./download_cifar.sh ...@@ -79,10 +76,12 @@ $./download_cifar.sh
``` ```
## 模型配置说明 ## 模型配置说明
由于对抗式生产网络涉及到多个神经网络,所以必须用paddle Python API来训练。下面的介绍也可以部分的拿来当作paddle Python API的使用说明。
由于对抗式生产网络涉及到多个神经网络,所以必须用PaddlePaddle Python API来训练。下面的介绍也可以部分的拿来当作PaddlePaddle Python API的使用说明。
### 模型结构 ### 模型结构
在文件gan_conf.py当中我们定义了三个网络, **generator_training**, **discriminator_training** and **generator**. 和前文提到的模型结构的关系是:**discriminator_training** 是分类器,**generator** 是生成器,**generator_training** 是生成器加分类器因为训练生成器时需要用到分类器提供目标函数。这个对应关系在下面这段代码中定义:
在文件`gan_conf.py`当中我们定义了三个网络, **generator_training**, **discriminator_training** and **generator**. 和前文提到的模型结构的关系是:**discriminator_training** 是分类器,**generator** 是生成器,**generator_training** 是生成器加分类器因为训练生成器时需要用到分类器提供目标函数。这个对应关系在下面这段代码中定义:
```python ```python
if is_generator_training: if is_generator_training:
...@@ -105,7 +104,13 @@ if is_generator: ...@@ -105,7 +104,13 @@ if is_generator:
outputs(generator(noise)) outputs(generator(noise))
``` ```
为了能够训练在gan_conf.py中定义的网络,我们需要如下几个步骤:初始化Paddle环境,解析设置,由设置创造GradientMachine以及由GradientMachine创造trainer。这几步分别由下面几段代码实现: 为了能够训练在`gan_conf.py`中定义的网络,我们需要如下几个步骤:
1. 初始化Paddle环境,
1. 解析设置,
1. 由设置创造GradientMachine以及由GradientMachine创造trainer。
这几步分别由下面几段代码实现:
```python ```python
import py_paddle.swig_paddle as api import py_paddle.swig_paddle as api
...@@ -132,7 +137,7 @@ dis_trainer = api.Trainer.create(dis_conf, dis_training_machine) ...@@ -132,7 +137,7 @@ dis_trainer = api.Trainer.create(dis_conf, dis_training_machine)
gen_trainer = api.Trainer.create(gen_conf, gen_training_machine) gen_trainer = api.Trainer.create(gen_conf, gen_training_machine)
``` ```
为了能够平衡生成器和分类器之间的能力,我们依据它们各自的损失函数的大小来决定训练对象,即我们选择训练那个损失函数更大的网络。损失函数的值可以通过GradientMachine的forward pass来计算。 为了能够平衡生成器和分类器之间的能力,我们依据它们各自的损失函数的大小来决定训练对象,即我们选择训练那个损失函数更大的网络。损失函数的值可以通过调用`GradientMachine``forward`方法来计算。
```python ```python
def get_training_loss(training_machine, inputs): def get_training_loss(training_machine, inputs):
...@@ -155,7 +160,8 @@ copy_shared_parameters(gen_training_machine, generator_machine) ...@@ -155,7 +160,8 @@ copy_shared_parameters(gen_training_machine, generator_machine)
``` ```
### 数据定义 ### 数据定义
这里数据没有通过dataprovider提供,而是在gan_trainer.py里面直接产生data_batch并以Arguments的形式提供给trainer。
这里数据没有通过`dataprovider`提供,而是在`gan_trainer.py`里面直接产生读取minibatches,并以`Arguments`的形式提供给trainer。
```python ```python
def prepare_generator_data_batch(batch_size, noise): def prepare_generator_data_batch(batch_size, noise):
...@@ -173,7 +179,7 @@ gen_trainer.trainOneDataBatch(batch_size, data_batch_gen) ...@@ -173,7 +179,7 @@ gen_trainer.trainOneDataBatch(batch_size, data_batch_gen)
### 算法配置 ### 算法配置
在这里,我们指定了模型的训练参数, 选择学习率和batch size。这里的beta1参数比默认值0.9小很多是为了使学习的过程更稳定。 在这里,我们指定了模型的训练参数, 选择learning rate和batch size。这里的`beta1`参数比默认值0.9小很多是为了使学习的过程更稳定。
```python ```python
settings( settings(
...@@ -183,7 +189,8 @@ settings( ...@@ -183,7 +189,8 @@ settings(
``` ```
##训练模型 ## 训练模型
用MNIST手写数字图片训练对抗式生成网络可以用如下的命令: 用MNIST手写数字图片训练对抗式生成网络可以用如下的命令:
```bash ```bash
...@@ -191,7 +198,8 @@ $python gan_trainer.py -d mnist --useGpu 1 ...@@ -191,7 +198,8 @@ $python gan_trainer.py -d mnist --useGpu 1
``` ```
## 应用模型 ## 应用模型
图片由训练好的生成器生成。以下的代码将噪音z输入到生成器 G 当中,通过向前传递得到生成的图片。
图片由训练好的生成器生成。以下的代码将随机向量输入到模型 G,通过向前传递得到生成的图片。
```python ```python
def get_fake_samples(generator_machine, batch_size, noise): def get_fake_samples(generator_machine, batch_size, noise):
...@@ -207,9 +215,11 @@ fake_samples = get_fake_samples(generator_machine, batch_size, noise) ...@@ -207,9 +215,11 @@ fake_samples = get_fake_samples(generator_machine, batch_size, noise)
``` ```
## 总结 ## 总结
本章中,我们介绍了对抗式生成网络的基本概念,训练方法以及如何用Paddle来实现。对抗式生成网络是现有生成模型当中非常重要的一种,它可以利用大量无标记数据来进行非监督学习,以寄希望能够得到对于复杂高维数据的一般有效的表示。
本章中,我们介绍了对抗式生成网络的基本概念,训练方法以及如何用PaddlePaddle来训练一个简单的图像生成模型。对抗式生成网络是一种新的训练生成模型的有效方法,我们期待看到它的更有意思的应用场景。
## 参考文献 ## 参考文献
1. Goodfellow I, Pouget-Abadie J, Mirza M, et al. [Generative adversarial nets](https://arxiv.org/pdf/1406.2661v1.pdf)[C] Advances in Neural Information Processing Systems. 2014 1. Goodfellow I, Pouget-Abadie J, Mirza M, et al. [Generative adversarial nets](https://arxiv.org/pdf/1406.2661v1.pdf)[C] Advances in Neural Information Processing Systems. 2014
2. Radford A, Metz L, Chintala S. [Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks](https://arxiv.org/pdf/1511.06434v2.pdf)[C] arXiv preprint arXiv:1511.06434. 2015 2. Radford A, Metz L, Chintala S. [Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks](https://arxiv.org/pdf/1511.06434v2.pdf)[C] arXiv preprint arXiv:1511.06434. 2015
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册