Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
book
提交
defbd27f
B
book
项目概览
PaddlePaddle
/
book
通知
16
Star
4
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
40
列表
看板
标记
里程碑
合并请求
37
Wiki
5
Wiki
分析
仓库
DevOps
项目成员
Pages
B
book
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
40
Issue
40
列表
看板
标记
里程碑
合并请求
37
合并请求
37
Pages
分析
分析
仓库分析
DevOps
Wiki
5
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
defbd27f
编写于
1月 06, 2017
作者:
W
wangyang59
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
added history and introduction
上级
fce00da7
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
357 addition
and
15 deletion
+357
-15
gan/.gitignore
gan/.gitignore
+8
-0
gan/README.md
gan/README.md
+349
-15
gan/image/cifar_comparisons.jpg
gan/image/cifar_comparisons.jpg
+0
-0
gan/image/dcgan.png
gan/image/dcgan.png
+0
-0
gan/image/gan.png
gan/image/gan.png
+0
-0
gan/image/gan_conf_graph.png
gan/image/gan_conf_graph.png
+0
-0
gan/image/gan_ig.png
gan/image/gan_ig.png
+0
-0
gan/image/mnist_sample.png
gan/image/mnist_sample.png
+0
-0
gan/image/uniform_sample.png
gan/image/uniform_sample.png
+0
-0
未找到文件。
gan/.gitignore
0 → 100644
浏览文件 @
defbd27f
output/
uniform_params/
cifar_params/
mnist_params/
*.log
*.pyc
data/mnist_data/
data/cifar-10-batches-py/
gan/README.md
浏览文件 @
defbd27f
...
...
@@ -2,20 +2,29 @@
## 背景介绍
本章我们介绍对抗式生成网络,也称为Generative Adversarial Network (GAN)
\[
[
1
](
#参考文献
)
\]
。GAN的核心思想是,为了更好地训练一个生成式神经元网络模型(generative model),我们引入一个判别式神经元网络模型来构造优化目标函数。实验证明,
在图像自动生成、图像去噪、和缺失图像补全等应用里,这种方法可以训练处
一个能更逼近训练数据分布的生成式模型。
本章我们介绍对抗式生成网络,也称为Generative Adversarial Network (GAN)
\[
[
1
](
#参考文献
)
\]
。GAN的核心思想是,为了更好地训练一个生成式神经元网络模型(generative model),我们引入一个判别式神经元网络模型来构造优化目标函数。实验证明,
这种方法可以训练出
一个能更逼近训练数据分布的生成式模型。
到目前为止,大部分取得好的应用效果的神经元网络模型都是有监督训练(supervised learning)的判别式模型(discriminative models),包括图像识别中使用的convolutional networks和在语音识别中使用的connectionist temporal classification (CTC) networks。在这些例子里,训练数据 X 都是带有标签 y 的——每张图片附带了一个或者多个tag,每段语音附带了一段对应的文本;而模型的输入是 X,输出是 y,训练得到的模型表示从X到y的映射函数 y=f(X)。
和判别式神经元网络模型相对的一类模型是生成式模型(generative models)。它们通常是通过非监督训练(unsupervised learning)来得到的。这类模型的训练数据里只有 X,没有y。训练的目标是希望模型能蕴含训练数据的统计分布信息,从而可以从训练好的模型里产生出新的、在训练数据里没有出现过的新数据 x'。
一些为人熟知的生成模型的例子包括受限玻尔兹曼机(Restricted Boltzmann Machine)
\[
[
4
](
#参考文献
)
\]
,深度玻尔兹曼机(Deep Boltzmann Machine)
\[
[
5
](
#参考文献
)
\]
,神经自回归分布估计(Neural Autoregressive Distribution Estimator)
\[
[
6
](
#参考文献
)
\]
等
。
生成模型在很多方面都有广泛应用。比如在图像处理方面的图像自动生成、图像去噪、和缺失图像补全等应用。比如在增强学习的条件下,可以根据之前观测到的数据和可能的操作来生成未来的数据,使得agent能够从中选择最佳的操作。比如在半监督(semi-supervised)学习的条件下,把生成模型生成的数据加入分类器训练当中,能够减少分类器训练对于标记数据数量的要求。真实世界中大量数据都是没有标注的,人为标注数据会耗费大量人力财力,这就使生成模型有了它的用武之地
。
近年出现了一些专门用来生成图像的模型,一种是变分自编码器(variational autoencoder)
\[
[
3
](
#参考文献
)
\]
,它是在概率图模型(probabilistic graphical model)的框架下面搭建了一个生成模型,对数据有完整的概率描述,训练时是通过调节参数来最大化数据的概率。用这种方法产生的图片,虽然似然(likelihood)比较高,但经常看起来比较模糊。另一种是像素循环神经网络(Pixel Recurrent Neural Network)
\[
[
7
](
#参考文献
)
\]
,它是通过根据周围的像素来一个像素一个像素的生成图片,但这种方法生成的图片在全局看来会不太一致。为了解决这些问题,人们又提出了本章所要介绍的另一种生成模型,对抗式生成网络。
之前出现的生成模型,一般是直接构造模型$P_model(x;
\t
heta)$来模拟真实数据分布$P_data(x)$。而这个模拟的过程,通常是由最大似然(Maximum Likelihood)的办法来调节模型参数,使得观测到的真实数据在该模型下概率最大。这里模型的种类又可以分为两大类,一类是tractable的,一类是untractable的。第一类里的一个例子是像素循环神经网络(Pixel Recurrent Neural Network)
\[
[
7
](
#参考文献
)
\]
,它是用概率的链式规则把对于n维数据的概率分解成n个一维数据的概率相乘,也就是说根据周围的像素来一个像素一个像素的生成图片。这种方法的问题是对于一个n维的数据,需要n步才能生成,速度较慢,而且图片整体看来各处不连续。
为了能有更复杂的模型来模拟数据分布,人们提出了第二类untractable的模型,这样就只能用近似的办法来学习模型参数。近似的办法一种是构造一个似然的下限(Likelihood lower-bound),然后用变分的办法来提高这个下限的值,其中一个例子是变分自编码器(variational autoencoder)
\[
[
3
](
#参考文献
)
\]
。用这种方法产生的图片,虽然似然比较高,但经常看起来会比较模糊。近似的另一种办法是通过马尔可夫链-蒙地卡罗(Markov-Chain-Monte-Carlo)来取样本,比如深度玻尔兹曼机(Deep Boltzmann Machine)
\[
[
5
](
#参考文献
)
\]
就是用的这个方法。这种方法的问题是取样本的计算量非常大,而且没有办法并行化。
为了解决这些问题,人们又提出了本章所要介绍的另一种生成模型,对抗式生成网络。它相比于前面提到的方法,具有生成网络结构灵活,产生样本快,生成图像看起来更真实的优点。
<p
align=
"center"
>
<img
src=
"./image/cifar_comparisons.jpg"
width=
"700"
height=
"200"
><br/>
图1. Cifar-10生成图像对比
</p>
本文介绍如何训练一个产生式神经元网络模型,它的输入是一个随机生成的向量(相当于不需要任何有意义的输入),而输出是一幅图像,其中有一个数字。换句话说,我们训练一个会写字(阿拉伯数字)的神经元网络模型。它“写”的一些数字如下图:
<p
align=
"center"
>
<img
src=
"./mnist_sample.png"
width=
"300"
height=
"300"
><br/>
<img
src=
"./
image/
mnist_sample.png"
width=
"300"
height=
"300"
><br/>
图1. GAN生成的MNIST例图
</p>
...
...
@@ -27,22 +36,22 @@
因为神经元网络是一个有向图,总是有输入和输出的。当我们用无监督学习方式来训练一个神经元网络,用于描述训练数据分布的时候,一个通常的学习目标是估计一组参数,使得输出和输入很接近 —— 或者说输入是什么输出就是什么。很多早期的生成式神经元网络模型,包括受限波尔茨曼机和 autoencoder 都是这么训练的。这种情况下优化目标经常是最小化输出和输入的差别。
<p
align=
"center"
>
<img
src=
"./
gan.png"
width=
"500"
height=
"3
00"
><br/>
<img
src=
"./
image/gan_ig.png"
width=
"500"
height=
"4
00"
><br/>
图2. GAN模型原理示意图
<a
href=
"https://
ishmaelbelghazi.github.io/ALI/
"
>
figure credit
</a>
<a
href=
"https://
arxiv.org/pdf/1701.00160v1.pdf
"
>
figure credit
</a>
</p>
对抗式训练里,我们用一个判别式模型 D 辅助构造优化目标函数,来训练一个生成式模型 G。如图2所示。具体训练流程是不断交替执行如下两步:
1.
更新模型 D:
1.
固定G的参数不变,对于一组随机输入,得到一组(产生式)输出,$X_f$,并且将其label成“假”。
1
.
从训练数据 X 采样一组 $X_r$,并且label为“真”。
1
.
用这两组数据更新模型 D,从而使D能够分辨G产生的数据和真实训练数据。
2
.
从训练数据 X 采样一组 $X_r$,并且label为“真”。
3
.
用这两组数据更新模型 D,从而使D能够分辨G产生的数据和真实训练数据。
1
.
更新模型 G:
2
.
更新模型 G:
1.
把G的输出和D的输入连接起来,得到一个网路。
1
.
给G一组随机输入,期待G的输出让D认为像是“真”的。
1
.
在D的输出端,优化目标是通过更新G的参数来最小化D的输出和“真”的差别。
2
.
给G一组随机输入,期待G的输出让D认为像是“真”的。
3
.
在D的输出端,优化目标是通过更新G的参数来最小化D的输出和“真”的差别。
上述方法实际上在优化如下目标:
...
...
@@ -53,14 +62,14 @@ $$\min_G \max_D \frac{1}{N}\sum_{i=1}^N[\log D(x^i) + \log(1-D(G(z^i)))]$$
在最早的对抗式生成网络的论文中,生成器和分类器用的都是全联接层。在附带的代码
[
`gan_conf.py`
](
./gan_conf.py
)
中,我们实现了一个类似的结构。G和D是由三层全联接层构成,并且在某些全联接层后面加入了批标准化层(batch normalization)。所用网络结构在图3中给出。
<p
align=
"center"
>
<img
src=
"./gan_conf_graph.png"
width=
"700"
height=
"400"
><br/>
<img
src=
"./
image/
gan_conf_graph.png"
width=
"700"
height=
"400"
><br/>
图3. GAN模型结构图
</p>
由于上面的这种网络都是由全联接层组成,所以没有办法很好的生成图片数据,也没有办法做的很深。所以在随后的论文中,人们提出了深度卷积对抗式生成网络(deep convolutional generative adversarial network or DCGAN)
\[
[
2
](
#参考文献
)
\]
。在DCGAN中,生成器 G 是由多个卷积转置层(transposed convolution)组成的,这样可以用更少的参数来生成质量更高的图片。具体网络结果可参见图4。而判别器是由多个卷积层组成。
<p
align=
"center"
>
<img
src=
"./dcgan.png"
width=
"700"
height=
"300"
><br/>
<img
src=
"./
image/
dcgan.png"
width=
"700"
height=
"300"
><br/>
图4. DCGAN生成器模型结构
<a
href=
"https://arxiv.org/pdf/1511.06434v2.pdf"
>
figure credit
</a>
</p>
...
...
@@ -131,6 +140,331 @@ settings(
```
### 模型结构
本章里我们主要用到两种模型。一种是基本的GAN模型,主要由全联接层搭建,在gan_conf.py里面定义。另一种是DCGAN模型,主要由卷基层搭建,在gan_conf_image.py里面定义。
```
python
# 下面这个函数定义了GAN模型里面的判别器结构
def
discriminator
(
sample
):
"""
discriminator ouputs the probablity of a sample is from generator
or real data.
The output has two dimenstional: dimension 0 is the probablity
of the sample is from generator and dimension 1 is the probabblity
of the sample is from real data.
"""
param_attr
=
ParamAttr
(
is_static
=
is_generator_training
)
bias_attr
=
ParamAttr
(
is_static
=
is_generator_training
,
initial_mean
=
1.0
,
initial_std
=
0
)
hidden
=
fc_layer
(
input
=
sample
,
name
=
"dis_hidden"
,
size
=
hidden_dim
,
bias_attr
=
bias_attr
,
param_attr
=
param_attr
,
act
=
ReluActivation
())
hidden2
=
fc_layer
(
input
=
hidden
,
name
=
"dis_hidden2"
,
size
=
hidden_dim
,
bias_attr
=
bias_attr
,
param_attr
=
param_attr
,
act
=
LinearActivation
())
hidden_bn
=
batch_norm_layer
(
hidden2
,
act
=
ReluActivation
(),
name
=
"dis_hidden_bn"
,
bias_attr
=
bias_attr
,
param_attr
=
ParamAttr
(
is_static
=
is_generator_training
,
initial_mean
=
1.0
,
initial_std
=
0.02
),
use_global_stats
=
False
)
return
fc_layer
(
input
=
hidden_bn
,
name
=
"dis_prob"
,
size
=
2
,
bias_attr
=
bias_attr
,
param_attr
=
param_attr
,
act
=
SoftmaxActivation
())
# 下面这个函数定义了GAN模型里面生成器的结构
def
generator
(
noise
):
"""
generator generates a sample given noise
"""
param_attr
=
ParamAttr
(
is_static
=
is_discriminator_training
)
bias_attr
=
ParamAttr
(
is_static
=
is_discriminator_training
,
initial_mean
=
1.0
,
initial_std
=
0
)
hidden
=
fc_layer
(
input
=
noise
,
name
=
"gen_layer_hidden"
,
size
=
hidden_dim
,
bias_attr
=
bias_attr
,
param_attr
=
param_attr
,
act
=
ReluActivation
())
hidden2
=
fc_layer
(
input
=
hidden
,
name
=
"gen_hidden2"
,
size
=
hidden_dim
,
bias_attr
=
bias_attr
,
param_attr
=
param_attr
,
act
=
LinearActivation
())
hidden_bn
=
batch_norm_layer
(
hidden2
,
act
=
ReluActivation
(),
name
=
"gen_layer_hidden_bn"
,
bias_attr
=
bias_attr
,
param_attr
=
ParamAttr
(
is_static
=
is_discriminator_training
,
initial_mean
=
1.0
,
initial_std
=
0.02
),
use_global_stats
=
False
)
return
fc_layer
(
input
=
hidden_bn
,
name
=
"gen_layer1"
,
size
=
sample_dim
,
bias_attr
=
bias_attr
,
param_attr
=
param_attr
,
act
=
LinearActivation
())
```
```
python
# 一个卷积/卷积转置层和一个批标准化层打包在一起
def
conv_bn
(
input
,
channels
,
imgSize
,
num_filters
,
output_x
,
stride
,
name
,
param_attr
,
bias_attr
,
param_attr_bn
,
bn
,
trans
=
False
,
act
=
ReluActivation
()):
"""
conv_bn is a utility function that constructs a convolution/deconv layer
with an optional batch_norm layer
:param bn: whether to use batch_norm_layer
:type bn: bool
:param trans: whether to use conv (False) or deconv (True)
:type trans: bool
"""
# calculate the filter_size and padding size based on the given
# imgSize and ouput size
tmp
=
imgSize
-
(
output_x
-
1
)
*
stride
if
tmp
<=
1
or
tmp
>
5
:
raise
ValueError
(
"conv input-output dimension does not fit"
)
elif
tmp
<=
3
:
filter_size
=
tmp
+
2
padding
=
1
else
:
filter_size
=
tmp
padding
=
0
print
(
imgSize
,
output_x
,
stride
,
filter_size
,
padding
)
if
trans
:
nameApx
=
"_conv"
else
:
nameApx
=
"_convt"
if
bn
:
conv
=
img_conv_layer
(
input
,
filter_size
=
filter_size
,
num_filters
=
num_filters
,
name
=
name
+
nameApx
,
num_channels
=
channels
,
act
=
LinearActivation
(),
groups
=
1
,
stride
=
stride
,
padding
=
padding
,
bias_attr
=
bias_attr
,
param_attr
=
param_attr
,
shared_biases
=
True
,
layer_attr
=
None
,
filter_size_y
=
None
,
stride_y
=
None
,
padding_y
=
None
,
trans
=
trans
)
conv_bn
=
batch_norm_layer
(
conv
,
act
=
act
,
name
=
name
+
nameApx
+
"_bn"
,
bias_attr
=
bias_attr
,
param_attr
=
param_attr_bn
,
use_global_stats
=
False
)
return
conv_bn
else
:
conv
=
img_conv_layer
(
input
,
filter_size
=
filter_size
,
num_filters
=
num_filters
,
name
=
name
+
nameApx
,
num_channels
=
channels
,
act
=
act
,
groups
=
1
,
stride
=
stride
,
padding
=
padding
,
bias_attr
=
bias_attr
,
param_attr
=
param_attr
,
shared_biases
=
True
,
layer_attr
=
None
,
filter_size_y
=
None
,
stride_y
=
None
,
padding_y
=
None
,
trans
=
trans
)
return
conv
# 下面这个函数定义了DCGAN模型里面的生成器的结构
def
generator
(
noise
):
"""
generator generates a sample given noise
"""
param_attr
=
ParamAttr
(
is_static
=
is_discriminator_training
,
initial_mean
=
0.0
,
initial_std
=
0.02
)
bias_attr
=
ParamAttr
(
is_static
=
is_discriminator_training
,
initial_mean
=
0.0
,
initial_std
=
0.0
)
param_attr_bn
=
ParamAttr
(
is_static
=
is_discriminator_training
,
initial_mean
=
1.0
,
initial_std
=
0.02
)
h1
=
fc_layer
(
input
=
noise
,
name
=
"gen_layer_h1"
,
size
=
s8
*
s8
*
gf_dim
*
4
,
bias_attr
=
bias_attr
,
param_attr
=
param_attr
,
act
=
LinearActivation
())
h1_bn
=
batch_norm_layer
(
h1
,
act
=
ReluActivation
(),
name
=
"gen_layer_h1_bn"
,
bias_attr
=
bias_attr
,
param_attr
=
param_attr_bn
,
use_global_stats
=
False
)
h2_bn
=
conv_bn
(
h1_bn
,
channels
=
gf_dim
*
4
,
output_x
=
s8
,
num_filters
=
gf_dim
*
2
,
imgSize
=
s4
,
stride
=
2
,
name
=
"gen_layer_h2"
,
param_attr
=
param_attr
,
bias_attr
=
bias_attr
,
param_attr_bn
=
param_attr_bn
,
bn
=
True
,
trans
=
True
)
h3_bn
=
conv_bn
(
h2_bn
,
channels
=
gf_dim
*
2
,
output_x
=
s4
,
num_filters
=
gf_dim
,
imgSize
=
s2
,
stride
=
2
,
name
=
"gen_layer_h3"
,
param_attr
=
param_attr
,
bias_attr
=
bias_attr
,
param_attr_bn
=
param_attr_bn
,
bn
=
True
,
trans
=
True
)
return
conv_bn
(
h3_bn
,
channels
=
gf_dim
,
output_x
=
s2
,
num_filters
=
c_dim
,
imgSize
=
sample_dim
,
stride
=
2
,
name
=
"gen_layer_h4"
,
param_attr
=
param_attr
,
bias_attr
=
bias_attr
,
param_attr_bn
=
param_attr_bn
,
bn
=
False
,
trans
=
True
,
act
=
TanhActivation
())
# 下面这个函数定义了DCGAN模型里面的判别器结构
def
discriminator
(
sample
):
"""
discriminator ouputs the probablity of a sample is from generator
or real data.
The output has two dimenstional: dimension 0 is the probablity
of the sample is from generator and dimension 1 is the probabblity
of the sample is from real data.
"""
param_attr
=
ParamAttr
(
is_static
=
is_generator_training
,
initial_mean
=
0.0
,
initial_std
=
0.02
)
bias_attr
=
ParamAttr
(
is_static
=
is_generator_training
,
initial_mean
=
0.0
,
initial_std
=
0.0
)
param_attr_bn
=
ParamAttr
(
is_static
=
is_generator_training
,
initial_mean
=
1.0
,
initial_std
=
0.02
)
h0
=
conv_bn
(
sample
,
channels
=
c_dim
,
imgSize
=
sample_dim
,
num_filters
=
df_dim
,
output_x
=
s2
,
stride
=
2
,
name
=
"dis_h0"
,
param_attr
=
param_attr
,
bias_attr
=
bias_attr
,
param_attr_bn
=
param_attr_bn
,
bn
=
False
)
h1_bn
=
conv_bn
(
h0
,
channels
=
df_dim
,
imgSize
=
s2
,
num_filters
=
df_dim
*
2
,
output_x
=
s4
,
stride
=
2
,
name
=
"dis_h1"
,
param_attr
=
param_attr
,
bias_attr
=
bias_attr
,
param_attr_bn
=
param_attr_bn
,
bn
=
True
)
h2_bn
=
conv_bn
(
h1_bn
,
channels
=
df_dim
*
2
,
imgSize
=
s4
,
num_filters
=
df_dim
*
4
,
output_x
=
s8
,
stride
=
2
,
name
=
"dis_h2"
,
param_attr
=
param_attr
,
bias_attr
=
bias_attr
,
param_attr_bn
=
param_attr_bn
,
bn
=
True
)
return
fc_layer
(
input
=
h2_bn
,
name
=
"dis_prob"
,
size
=
2
,
bias_attr
=
bias_attr
,
param_attr
=
param_attr
,
act
=
SoftmaxActivation
())
```
在文件
`gan_conf.py`
当中我们定义了三个网络,
**generator_training**
,
**discriminator_training**
and
**generator**
. 和前文提到的模型结构的关系是:
**discriminator_training**
是分类器,
**generator**
是生成器,
**generator_training**
是生成器加分类器因为训练生成器时需要用到分类器提供目标函数。这个对应关系在下面这段代码中定义:
...
...
@@ -185,8 +519,8 @@ I0105 17:16:37.172737 20517 TrainerInternal.cpp:165] Batch=100 samples=12800 Av
为了能够训练在gan_conf.py中定义的网络,我们需要如下几个步骤:
1.
初始化Paddle环境,
1
.
解析设置,
1
.
由设置创造GradientMachine以及由GradientMachine创造trainer。
2
.
解析设置,
3
.
由设置创造GradientMachine以及由GradientMachine创造trainer。
这几步分别由下面几段代码实现:
...
...
gan/image/cifar_comparisons.jpg
0 → 100644
浏览文件 @
defbd27f
456.9 KB
gan/dcgan.png
→
gan/
image/
dcgan.png
浏览文件 @
defbd27f
文件已移动
gan/gan.png
→
gan/
image/
gan.png
浏览文件 @
defbd27f
文件已移动
gan/gan_conf_graph.png
→
gan/
image/
gan_conf_graph.png
浏览文件 @
defbd27f
文件已移动
gan/image/gan_ig.png
0 → 100644
浏览文件 @
defbd27f
229.4 KB
gan/mnist_sample.png
→
gan/
image/
mnist_sample.png
浏览文件 @
defbd27f
文件已移动
gan/uniform_sample.png
→
gan/
image/
uniform_sample.png
浏览文件 @
defbd27f
文件已移动
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录