提交 68d54ecb 编写于 作者: W wangyang59

code modification

上级 a0411e3c
...@@ -2,113 +2,107 @@ ...@@ -2,113 +2,107 @@
## 背景介绍 ## 背景介绍
本章我们介绍对抗式生成网络,也称为Generative Adversarial Network (GAN) \[[1](#参考文献)\]。GAN的核心思想是,为了更好地训练一个生成式神经元网络模型(generative model),我们引入一个判别式神经元网络模型来构造优化目标函数。实验证明,这种方法可以训练出一个能更逼近训练数据分布的生成式模型 之前的几章中,我们主要介绍的深度学习模型都是在有监督学习(supervised learning)条件下的判别式模型(discriminative models)。在这些例子里,训练数据 X 都是带有标签 y 的,如图像识别中的类别标号,或是语音识别中对应的真实文本。模型的输入是 X,输出是 y,训练得到的模型表示从X到y的映射函数 y=f(X)
到目前为止,大部分取得好的应用效果的神经元网络模型都是有监督训练(supervised learning)的判别式模型(discriminative models),包括图像识别中使用的convolutional networks和在语音识别中使用的connectionist temporal classification (CTC) networks。在这些例子里,训练数据 X 都是带有标签 y 的——每张图片附带了一个或者多个tag,每段语音附带了一段对应的文本;而模型的输入是 X,输出是 y,训练得到的模型表示从X到y的映射函数 y=f(X) 和判别式网络模型相对的一类模型是生成式模型(generative models)。它们通常是通过非监督训练(unsupervised learning)来得到的。这类模型的训练数据里只有 X,没有y。训练的目标是希望模型能蕴含训练数据的统计分布信息,从而可以从训练好的模型里产生出新的、在训练数据里没有出现过的新数据 X'
和判别式神经元网络模型相对的一类模型是生成式模型(generative models)。它们通常是通过非监督训练(unsupervised learning)来得到的。这类模型的训练数据里只有 X,没有y。训练的目标是希望模型能蕴含训练数据的统计分布信息,从而可以从训练好的模型里产生出新的、在训练数据里没有出现过的新数据 x' 生成模型在很多方向都有着广泛的应用。在图像处理方向,生成模型可以用来做图像自动生成、图像去噪、和缺失图像补全等应用。在半监督(semi-supervised)学习的条件下,把生成模型生成的数据加入分类器训练当中,能够减少分类器训练对于标记数据数量的要求。真实世界中大量数据都是没有标注的,人为标注数据会耗费大量人力财力,这就使生成模型有了它的用武之地。研究生成模型的另一个动机是,人们认为如果能够生成很好的数据,那么很可能这个生成模型就学习到了这组数据的一个很好的通用表示(representation),随后就可以用这个学到的表示来完成其他的一些任务
生成模型在很多方面都有广泛应用。比如在图像处理方面的图像自动生成、图像去噪、和缺失图像补全等应用。比如在增强学习的条件下,可以根据之前观测到的数据和可能的操作来生成未来的数据,使得agent能够从中选择最佳的操作。比如在半监督(semi-supervised)学习的条件下,把生成模型生成的数据加入分类器训练当中,能够减少分类器训练对于标记数据数量的要求。真实世界中大量数据都是没有标注的,人为标注数据会耗费大量人力财力,这就使生成模型有了它的用武之地。研究生成模型的另一个动机是,人们认为如果能够生成很好的数据,那么很可能这个生成模型就学习到了这组数据的一个很好的通用表示(representation),就可以用这个学到的表示来完成其他的一些任务。 之前出现的生成模型,一般是直接构造模型$P_{model}(x; \theta)$来模拟真实数据分布$P_{data}(x)$。而这个模拟的过程,通常是由最大似然(Maximum Likelihood)的办法来调节模型参数,使得观测到的真实数据在该模型下概率最大。现在常用的生成模型有以下几种:
之前出现的生成模型,一般是直接构造模型$P_{model}(x; \theta)$来模拟真实数据分布$P_{data}(x)$。而这个模拟的过程,通常是由最大似然(Maximum Likelihood)的办法来调节模型参数,使得观测到的真实数据在该模型下概率最大。这里模型的种类又可以分为两大类,一类是tractable的,一类是untractable的。第一类里的一个例子是像素循环神经网络(Pixel Recurrent Neural Network)\[[2](#参考文献)\],它是用概率的链式规则把对于n维数据的概率分解成n个一维数据的概率相乘,也就是说根据周围的像素来一个像素一个像素的生成图片。这种方法的问题是对于一个n维的数据,需要n步才能生成,速度较慢,而且图片整体看来各处不太连续。 1. 深度玻尔兹曼机(Deep Boltzmann Machine)\[[4](#参考文献)\]: 深度玻尔兹曼机是在概率图模型(probabilistc graphical model)的框架下由多个隐层(hidden layer)搭建的无向图模型(Markov random field)。具体的模型结构可参见图1,图中各层之间的是通过受限玻尔兹曼机(restricted Boltzmann machine)的结构来连接的。这类模型参数的学习一般需要通过马尔可夫链-蒙地卡罗(Markov-Chain-Monte-Carlo)的方法来取样本,所以计算量会比较大。
2. 变分自编码器(variational autoencoder)\[[3](#参考文献)\]:它是在概率图模型的框架下搭建的有向图模型,并结合了深度学习和统计推断方法,希望将数据表示在一个隐层空间(z)中,而其中仍用神经网络做非线性映射,具体模型结构可参加图1。这类模型的参数学习用的近似方法是构造一个似然的下限(Likelihood lower-bound),然后用变分的办法来提高这个下限的值,从而达到提高数据似然的目的。这种方法的问题是产生的图片看起来会比较模糊。
为了能有更复杂的模型来模拟数据分布,人们提出了第二类untractable的模型,这样就只能用近似的办法来学习模型参数。近似的办法一种是构造一个似然的下限(Likelihood lower-bound),然后用变分的办法来提高这个下限的值,其中一个例子是变分自编码器(variational autoencoder)\[[3](#参考文献)\]。用这种方法产生的图片,虽然似然比较高,但经常看起来会比较模糊。近似的另一种办法是通过马尔可夫链-蒙地卡罗(Markov-Chain-Monte-Carlo)来取样本,比如深度玻尔兹曼机(Deep Boltzmann Machine)\[[4](#参考文献)\]就是用的这个方法。这种方法的问题是取样本的计算量非常大,而且没有办法并行化。 3. 像素循环神经网络(Pixel Recurrent Neural Network)\[[2](#参考文献)\]:它是对每个像素相对于周围像素的条件概率进行建模,也就是说根据周围的像素来一个像素一个像素的生成图片。例如图1中红色像素$x_i$的值就是依赖于之前生成的所有蓝色像素。这种方法的问题是对于一个n维的数据,需要n步才能生成,速度较慢。
为了解决这些问题,人们又提出了本章所要介绍的另一种生成模型,对抗式生成网络。它相比于前面提到的方法,具有生成网络结构灵活,产生样本快,生成图像看起来更真实的优点。下面的图1就对比了上面介绍的几种方法在生成CIFAR-10图片时的效果。
<p align="center"> <p align="center">
<img src="./image/cifar_comparisons.jpg" width="1000" height="300"><br/> <img src="./image/background_intro.png" width="800" height="300"><br/>
图1. Cifar-10生成图像对比 图1. 生成模型概览
</p> </p>
为了解决上面这些模型的问题,人们又提出了本章所要介绍的另一种生成模型,对抗式生成网络。它相比于前面提到的方法,具有生成网络结构灵活,产生样本快,生成图像看起来更真实的优点。对抗式生成网络,也称为Generative Adversarial Network (GAN) \[[1](#参考文献)\]。GAN的核心思想是,为了更好地训练一个生成式神经元网络模型,我们引入一个判别式神经元网络模型来构造优化目标函数。
## 效果展示 ## 效果展示
本章将介绍如何训练一个对抗式生成网络,它的输入是一个随机生成的向量(相当于不需要任何有意义的输入),而输出是一幅图像,其中有一个数字。换句话说,我们训练一个会写字(阿拉伯数字)的神经元网络模型。它“写”的一些数字如下图: 一个训练好的对抗式生成网络,它的输入是一个随机生成的向量(相当于不需要任何有意义的输入),而输出是一个和训练数据相类似的数据样本。如果训练数据是二维单位均匀分布的数据,那么输出的也是二维单位均匀分布的数据(参见图2左);如果训练数据是MNIST手写数字图片,那么输出也是类似MNIST的数字图片(参见图2中,其中每个数字是由一个随机向量产生);如果训练数据是CIFAR物体图片,那么输出也是类似物体的图片(参见图2右)。
<p align="center"> <p align="center">
<img src="./image/mnist_sample.png" width="300" height="300"><br/> <img src="./image/gan_samples.jpg" width="800" height="300"><br/>
图2. GAN生成的MNIST例 图2. GAN生成效果展示
</p> </p>
## 模型概览 ## 模型概览
### 对抗式网络结构
对抗式生成网络的基本结构是将一个已知概率分布的随机变量$z$,通过参数化的概率生成模型(通常是用一个神经网络模型来进行参数化),变换后得到一个生成的概率分布(图3中绿色的分布)。训练生成模型的过程就是调节生成模型的参数,使得生成的概率分布趋向于真实数据的概率分布(图3中蓝色的分布)。 对抗式生成网络的基本结构是将一个已知概率分布的随机变量$z$,通过参数化的概率生成模型(通常是用一个神经网络模型来进行参数化),变换后得到一个生成的概率分布(图3中绿色的分布)。训练生成模型的过程就是调节生成模型的参数,使得生成的概率分布趋向于真实数据的概率分布(图3中蓝色的分布)。
对抗式生成网络和之前的生成模型最大的创新就在于,用一个判别式神经网络来描述生成的概率分布和真实数据概率分布之间的差别。也就是说,我们用一个判别式模型 D 辅助构造优化目标函数,来训练一个生成式模型 G。G和D在训练时是处在相互对抗的角色下,G的目标是尽量生成和真实数据看起来相似的伪数据,从而使得D无法分别数据的真伪;而D的目标是能尽量分别出哪些是真实数据,哪些是G生成的伪数据。两者在竞争的条件下,能够相互提高各自的能力,最后收敛到一个均衡点:生成器生成的数据分布和真实数据分布完全一样,而判别器完全无法区分数据的真伪。 对抗式生成网络和之前的生成模型最大的创新就在于,用一个判别式神经网络来描述生成的概率分布和真实数据概率分布之间的差别。也就是说,我们用一个判别式模型 D 辅助构造优化目标函数,来训练一个生成式模型 G。G和D在训练时是处在相互对抗的角色下,G的目标是尽量生成和真实数据看起来相似的伪数据,从而使得D无法分别数据的真伪;而D的目标是能尽量分别出哪些是真实数据,哪些是G生成的伪数据。两者在竞争的条件下,能够相互提高各自的能力,最后收敛到一个均衡点:生成器生成的数据分布和真实数据分布完全一样,而判别器完全无法区分数据的真伪。
<p align="center"> ### 对抗式训练方法
<img src="./image/gan_openai.png" width="700" height="300"><br/> 对抗式训练里,具体训练流程是不断交替执行如下两步(参见图3):
图3. GAN模型原理示意图
<a href="https://openai.com/blog/generative-models/">figure credit</a>
</p>
对抗式训练里,具体训练流程是不断交替执行如下两步(参见图4): 1. 更新判别器D:
1. 固定G的参数不变,对于一组随机输入,得到一组(产生式)输出,$X_f$,并将其标号(label)设置为"假"。
1. 更新模型 D: 2. 从训练数据 X 采样一组 $X_r$,并将其标号设置为"真"。
1. 固定G的参数不变,对于一组随机输入,得到一组(产生式)输出,$X_f$,并且将其label成“假”。
2. 从训练数据 X 采样一组 $X_r$,并且label为“真”。
3. 用这两组数据更新模型 D,从而使D能够分辨G产生的数据和真实训练数据。 3. 用这两组数据更新模型 D,从而使D能够分辨G产生的数据和真实训练数据。
2. 更新模型 G: 2. 更新生成器G:
1. 把G的输出和D的输入连接起来,得到一个网路。 1. 把G的输出和D的输入连接起来,得到一个网路。
2. 给G一组随机输入,期待G的输出让D认为像是“真”的 2. 给G一组随机输入,输出生成数据
3. 在D的输出端,优化目标是通过更新G的参数来最小化D的输出和“真”的差别。 3. 将G生成的数据输入D。在D的输出端,优化目标是通过更新G的参数来最小化D对生成数据的判别结果和“真”的差别。
<p align="center"> <p align="center">
<img src="./image/gan_ig.png" width="500" height="400"><br/> <img src="./image/gan_ig.png" width="500" height="400"><br/>
图4. GAN模型训练流程图 图3. GAN模型训练流程图 [6]
<a href="https://arxiv.org/pdf/1701.00160v1.pdf">figure credit</a>
</p> </p>
上述方法实际上在优化如下目标: 上述方法实际上在优化如下目标:
$$\min_G \max_D \frac{1}{N}\sum_{i=1}^N[\log D(x^i) + \log(1-D(G(z^i)))]$$ $$\min_G \max_D \frac{1}{N}\sum_{i=1}^N[\log D(x^i) + \log(1-D(G(z^i)))]$$
其中$x$是真实数据,$z$是随机产生的输入,$N$是训练数据的数量。这个损失函数的意思是:真实数据被分类为真的概率加上伪数据被分类为假的概率。因为上述两步交替优化G生成的结果的仿真程度(看起来像x)和D分辨G的生成结果和x的能力,所以这个方法被称为对抗(adversarial)方法。 其中$x$是真实数据,$z$是随机产生的输入,$N$是训练数据的数量。这个损失函数的意思是:真实数据被分类为真的概率加上伪数据被分类为假的概率。因为上述两步交替优化G生成结果的仿真程度(看起来像x),和D分辨真伪数据的能力,所以这个方法被称为对抗(adversarial)方法。
在最早的对抗式生成网络的论文中,生成器和分类器用的都是全联接层。在附带的代码[`gan_conf.py`](./gan_conf.py)中,我们实现了一个类似的结构。G和D是由三层全联接层构成,并且在某些全联接层后面加入了批标准化层(batch normalization)。所用网络结构在图5中给出。 ### 基本GAN模型
在最早的对抗式生成网络的论文中,生成器和判别器用的都是全连接层。我们在本章实现了一个类似的结构。G和D是由三层全连接层构成,并且在某些全联接层后面加入了批标准化层(batch normalization)。所用网络结构在图5中给出。
<p align="center"> <p align="center">
<img src="./image/gan_conf_graph.png" width="700" height="400"><br/> <img src="./image/gan_conf_graph.png" width="700" height="400"><br/>
图5. GAN模型结构图 图5. GAN模型结构图
</p> </p>
由于上面的这种网络都是由全联接层组成,所以没有办法很好的生成图片数据,也没有办法做的很深。所以在随后的论文中,人们提出了深度卷积对抗式生成网络(deep convolutional generative adversarial network or DCGAN)\[[5](#参考文献)\]。在DCGAN中,生成器 G 是由多个卷积转置层(transposed convolution)组成的,这样可以用更少的参数来生成质量更高的图片。具体网络结果可参见图6。而判别器是由多个卷积层组成。 ### DCGAN模型
由于上面的这种网络都是由全连接层组成,所以没有办法很好的生成图片数据,也没有办法做的很深。所以在随后的论文中,人们提出了深度卷积对抗式生成网络(deep convolutional generative adversarial network or DCGAN)\[[5](#参考文献)\]。在DCGAN中,生成器 G 是由多个卷积转置层(transposed convolution)组成的,这样可以用更少的参数来生成质量更高的图片,而判别器是由多个卷积层组成。卷积转置层和卷积层的关系是,卷积转置层的向前传递(feedforward)操作类似于卷积层的向后传递(back-propagation)操作。也就是说,卷积转着层向前传递时,输入图片的每个像素都和整个卷积核(kernal)相乘,然后把结果叠加到输出图片的相应位置\[[7](#参考文献)\]。 具体网络结构可参见图6。
<p align="center"> <p align="center">
<img src="./image/dcgan.png" width="700" height="300"><br/> <img src="./image/dcgan_conf_graph.png" width="700" height="400"><br/>
图6. DCGAN生成器模型结构 图6. DCGAN生成器模型结构
<a href="https://arxiv.org/pdf/1511.06434v2.pdf">figure credit</a>
</p> </p>
## 数据 ## 数据准备
这章会用到三种数据,一种是二维均匀分布随机数(后面会用基本GAN模型训练),一种是MNIST手写数字图片(后面会用DCGAN模型训练),一种是CIFAR-10物体图片(后面会用DCGAN模型训练)。
这章会用到两种数据,一种是G的随机输入,另一种是来自MNIST数据集的图片,其中一张是人类手写的一个数字。随机输入数据的生成方式如下: 二维均匀分布随机数的生成方式如下:
```python ```python
# 合成2-D均匀分布数据 gan_trainer.py:114 # 合成二维均匀分布随机数 gan_trainer.py:114
# numpy.random.rand会生成0-1之间的均匀分布随机数
# 后面的参数是让其生成两百万个这样的随机数,然后排成1000000*2的矩阵,
# 这样可以当做1000000个二维均匀分布随机数来使用
def load_uniform_data(): def load_uniform_data():
data = numpy.random.rand(1000000, 2).astype('float32') data = numpy.random.rand(1000000, 2).astype('float32')
return data return data
``` ```
MNIST数据可以通过执行[get_mnist_data.sh](./data/get_mnist_data.sh)下载: MNIST/CIFAR数据可以分别通过执行下面的命令来下载:
```bash ```bash
$cd data/ $cd data/
$./get_mnist_data.sh $./get_mnist_data.sh
$./get_cifar_data.sh
``` ```
其实只需要换一种图像数据集,这个例子即可训练G来生成对应的类似图像。比如Cifar-10数据集可由执行[download_cifa.sh](./data/download_cifa.sh)下载: ## 模型配置说明
```bash
$cd data/
$./download_cifar.sh
```
## 模型配置
由于对抗式生产网络涉及到多个神经网络,所以必须用PaddlePaddle Python API来训练。下面的介绍也可以部分的拿来当作PaddlePaddle Python API的使用说明。 由于对抗式生产网络涉及到多个神经网络,所以必须用PaddlePaddle Python API来训练。下面的介绍也可以部分的拿来当作PaddlePaddle Python API的使用说明。
...@@ -137,36 +131,62 @@ gen_trainer.trainOneDataBatch(batch_size, data_batch_gen) ...@@ -137,36 +131,62 @@ gen_trainer.trainOneDataBatch(batch_size, data_batch_gen)
### 算法配置 ### 算法配置
在这里,我们指定了模型的训练参数, 选择learning rate和batch size。这里`beta1`参数比默认值0.9小很多是为了使学习的过程更稳定。 在这里,我们指定了模型的训练参数, 选择学习率(learning rate)和batch size。这里用到的优化方法是AdamOptimizer\[[8](#参考文献)\],它`beta1`参数比默认值0.9小很多是为了使学习的过程更稳定。
```python ```python
settings( settings(
batch_size=128, batch_size=128,
learning_rate=1e-4, learning_rate=1e-4,
learning_method=AdamOptimizer(beta1=0.5)) learning_method=AdamOptimizer(beta1=0.5))
``` ```
### 模型结构 ### 模型结构
本章里我们主要用到两种模型。一种是基本的GAN模型,主要由全联接层搭建,在gan_conf.py里面定义。另一种是DCGAN模型,主要由卷基层搭建,在gan_conf_image.py里面定义。 本章里我们主要用到两种模型。一种是基本的GAN模型,主要由全连接层搭建,在gan_conf.py里面定义。另一种是DCGAN模型,主要由卷积层搭建,在gan_conf_image.py里面定义。
#### 对抗式网络基本框架
在文件`gan_conf.py``gan_conf_image.py`当中我们都定义了三个网络, **generator_training**, **discriminator_training** and **generator**. 和前文提到的模型结构的关系是:**discriminator_training** 是判别器,**generator** 是生成器,**generator_training** 是生成器加上判别器,这样构造的原因是因为训练生成器时需要用到判别器提供目标函数。这个对应关系在下面这段代码中定义:
```python ```python
# 下面这个函数定义了GAN模型里面的判别器结构 if is_generator_training:
def discriminator(sample): noise = data_layer(name="noise", size=noise_dim)
""" # 函数generator定义了生成器的结构,生成器输入噪音z,输出伪数据。
discriminator ouputs the probablity of a sample is from generator sample = generator(noise)
or real data.
The output has two dimenstional: dimension 0 is the probablity if is_discriminator_training:
of the sample is from generator and dimension 1 is the probabblity # 判别器输入为真实数据或者伪数据
of the sample is from real data. sample = data_layer(name="sample", size=sample_dim)
"""
param_attr = ParamAttr(is_static=is_generator_training) if is_generator_training or is_discriminator_training:
label = data_layer(name="label", size=1)
# 函数discriminator定义了判别器的结构,判别器输入样本数据,
# 输出为该样本为真的概率
prob = discriminator(sample)
cost = cross_entropy(input=prob, label=label)
classification_error_evaluator(
input=prob, label=label, name=mode + '_error')
outputs(cost)
if is_generator:
noise = data_layer(name="noise", size=noise_dim)
outputs(generator(noise))
```
不同的generator和discriminator函数就决定了不同的GAN模型的实现。下面我们就分别介绍基本GAN模型和DCGAN模型里面这两个函数的定义。
#### 基本GAN模型生成器和判别器定义
```python
# 下面这个函数定义了基本GAN模型里面生成器的结构,可参见图5
def generator(noise):
# 这里is_static=is_discriminator_training的意思是在训练判别器时,
# 生成器的参数保持不变。
param_attr = ParamAttr(is_static=is_discriminator_training)
# 这里bias的初始值设成1.0是为了让神经元都尽量处于开启(activated)状态
# 这样便于Relu层向回传递导数信号
bias_attr = ParamAttr( bias_attr = ParamAttr(
is_static=is_generator_training, initial_mean=1.0, initial_std=0) is_static=is_discriminator_training, initial_mean=1.0, initial_std=0)
hidden = fc_layer( hidden = fc_layer(
input=sample, input=noise,
name="dis_hidden", name="gen_layer_hidden",
size=hidden_dim, size=hidden_dim,
bias_attr=bias_attr, bias_attr=bias_attr,
param_attr=param_attr, param_attr=param_attr,
...@@ -174,7 +194,7 @@ def discriminator(sample): ...@@ -174,7 +194,7 @@ def discriminator(sample):
hidden2 = fc_layer( hidden2 = fc_layer(
input=hidden, input=hidden,
name="dis_hidden2", name="gen_hidden2",
size=hidden_dim, size=hidden_dim,
bias_attr=bias_attr, bias_attr=bias_attr,
param_attr=param_attr, param_attr=param_attr,
...@@ -183,33 +203,33 @@ def discriminator(sample): ...@@ -183,33 +203,33 @@ def discriminator(sample):
hidden_bn = batch_norm_layer( hidden_bn = batch_norm_layer(
hidden2, hidden2,
act=ReluActivation(), act=ReluActivation(),
name="dis_hidden_bn", name="gen_layer_hidden_bn",
bias_attr=bias_attr, bias_attr=bias_attr,
param_attr=ParamAttr( param_attr=ParamAttr(
is_static=is_generator_training, initial_mean=1.0, is_static=is_discriminator_training,
initial_mean=1.0,
initial_std=0.02), initial_std=0.02),
use_global_stats=False) use_global_stats=False)
return fc_layer( return fc_layer(
input=hidden_bn, input=hidden_bn,
name="dis_prob", name="gen_layer1",
size=2, size=sample_dim,
bias_attr=bias_attr, bias_attr=bias_attr,
param_attr=param_attr, param_attr=param_attr,
act=SoftmaxActivation()) act=LinearActivation())
# 下面这个函数定义了GAN模型里面生成器的结构 # 下面这个函数定义了基本GAN模型里面的判别器结构,可参见图5
def generator(noise): def discriminator(sample):
""" # 这里is_static=is_generator_training的意思是在训练生成器时,
generator generates a sample given noise # 判别器的参数保持不变。
""" param_attr = ParamAttr(is_static=is_generator_training)
param_attr = ParamAttr(is_static=is_discriminator_training)
bias_attr = ParamAttr( bias_attr = ParamAttr(
is_static=is_discriminator_training, initial_mean=1.0, initial_std=0) is_static=is_generator_training, initial_mean=1.0, initial_std=0)
hidden = fc_layer( hidden = fc_layer(
input=noise, input=sample,
name="gen_layer_hidden", name="dis_hidden",
size=hidden_dim, size=hidden_dim,
bias_attr=bias_attr, bias_attr=bias_attr,
param_attr=param_attr, param_attr=param_attr,
...@@ -217,7 +237,7 @@ def generator(noise): ...@@ -217,7 +237,7 @@ def generator(noise):
hidden2 = fc_layer( hidden2 = fc_layer(
input=hidden, input=hidden,
name="gen_hidden2", name="dis_hidden2",
size=hidden_dim, size=hidden_dim,
bias_attr=bias_attr, bias_attr=bias_attr,
param_attr=param_attr, param_attr=param_attr,
...@@ -226,23 +246,23 @@ def generator(noise): ...@@ -226,23 +246,23 @@ def generator(noise):
hidden_bn = batch_norm_layer( hidden_bn = batch_norm_layer(
hidden2, hidden2,
act=ReluActivation(), act=ReluActivation(),
name="gen_layer_hidden_bn", name="dis_hidden_bn",
bias_attr=bias_attr, bias_attr=bias_attr,
param_attr=ParamAttr( param_attr=ParamAttr(
is_static=is_discriminator_training, is_static=is_generator_training, initial_mean=1.0,
initial_mean=1.0,
initial_std=0.02), initial_std=0.02),
use_global_stats=False) use_global_stats=False)
return fc_layer( return fc_layer(
input=hidden_bn, input=hidden_bn,
name="gen_layer1", name="dis_prob",
size=sample_dim, size=2,
bias_attr=bias_attr, bias_attr=bias_attr,
param_attr=param_attr, param_attr=param_attr,
act=LinearActivation()) act=SoftmaxActivation())
``` ```
#### DCGAN模型生成器和判别器定义
```python ```python
# 一个卷积/卷积转置层和一个批标准化层打包在一起 # 一个卷积/卷积转置层和一个批标准化层打包在一起
def conv_bn(input, def conv_bn(input,
...@@ -258,17 +278,8 @@ def conv_bn(input, ...@@ -258,17 +278,8 @@ def conv_bn(input,
bn, bn,
trans=False, trans=False,
act=ReluActivation()): act=ReluActivation()):
""" # 根据输入图片的大小(imgSize)和输出图片的大小(output_x),
conv_bn is a utility function that constructs a convolution/deconv layer # 来计算所需的卷积核大小(filter_size)和边界补全大小(padding)
with an optional batch_norm layer
:param bn: whether to use batch_norm_layer
:type bn: bool
:param trans: whether to use conv (False) or deconv (True)
:type trans: bool
"""
# calculate the filter_size and padding size based on the given
# imgSize and ouput size
tmp = imgSize - (output_x - 1) * stride tmp = imgSize - (output_x - 1) * stride
if tmp <= 1 or tmp > 5: if tmp <= 1 or tmp > 5:
raise ValueError("conv input-output dimension does not fit") raise ValueError("conv input-output dimension does not fit")
...@@ -279,21 +290,24 @@ def conv_bn(input, ...@@ -279,21 +290,24 @@ def conv_bn(input,
filter_size = tmp filter_size = tmp
padding = 0 padding = 0
print(imgSize, output_x, stride, filter_size, padding)
if trans: if trans:
nameApx = "_conv" nameApx = "_conv"
else: else:
nameApx = "_convt" nameApx = "_convt"
# 如果conv层后面跟batchNorm层,那么conv层的activation必须是线性激发
if bn: if bn:
conv_act = LinearActivation()
else:
conv_act = act
conv = img_conv_layer( conv = img_conv_layer(
input, input,
filter_size=filter_size, filter_size=filter_size,
num_filters=num_filters, num_filters=num_filters,
name=name + nameApx, name=name + nameApx,
num_channels=channels, num_channels=channels,
act=LinearActivation(), act=conv_act,
groups=1, groups=1,
stride=stride, stride=stride,
padding=padding, padding=padding,
...@@ -306,6 +320,7 @@ def conv_bn(input, ...@@ -306,6 +320,7 @@ def conv_bn(input,
padding_y=None, padding_y=None,
trans=trans) trans=trans)
if bn:
conv_bn = batch_norm_layer( conv_bn = batch_norm_layer(
conv, conv,
act=act, act=act,
...@@ -313,39 +328,17 @@ def conv_bn(input, ...@@ -313,39 +328,17 @@ def conv_bn(input,
bias_attr=bias_attr, bias_attr=bias_attr,
param_attr=param_attr_bn, param_attr=param_attr_bn,
use_global_stats=False) use_global_stats=False)
return conv_bn return conv_bn
else: else:
conv = img_conv_layer(
input,
filter_size=filter_size,
num_filters=num_filters,
name=name + nameApx,
num_channels=channels,
act=act,
groups=1,
stride=stride,
padding=padding,
bias_attr=bias_attr,
param_attr=param_attr,
shared_biases=True,
layer_attr=None,
filter_size_y=None,
stride_y=None,
padding_y=None,
trans=trans)
return conv return conv
# 下面这个函数定义了DCGAN模型里面的生成器的结构 # 下面这个函数定义了DCGAN模型里面的生成器的结构,可参见图6
def generator(noise): def generator(noise):
""" # 这里的参数初始化设置参考了DCGAN论文里的建议和实际调试的结果
generator generates a sample given noise
"""
param_attr = ParamAttr( param_attr = ParamAttr(
is_static=is_discriminator_training, initial_mean=0.0, initial_std=0.02) is_static=is_discriminator_training, initial_mean=0.0, initial_std=0.02)
bias_attr = ParamAttr( bias_attr = ParamAttr(
is_static=is_discriminator_training, initial_mean=0.0, initial_std=0.0) is_static=is_discriminator_training, initial_mean=0.0, initial_std=0.0)
param_attr_bn = ParamAttr( param_attr_bn = ParamAttr(
is_static=is_discriminator_training, initial_mean=1.0, initial_std=0.02) is_static=is_discriminator_training, initial_mean=1.0, initial_std=0.02)
...@@ -408,20 +401,12 @@ def generator(noise): ...@@ -408,20 +401,12 @@ def generator(noise):
trans=True, trans=True,
act=TanhActivation()) act=TanhActivation())
# 下面这个函数定义了DCGAN模型里面的判别器结构 # 下面这个函数定义了DCGAN模型里面的判别器结构,可参见图6
def discriminator(sample): def discriminator(sample):
"""
discriminator ouputs the probablity of a sample is from generator
or real data.
The output has two dimenstional: dimension 0 is the probablity
of the sample is from generator and dimension 1 is the probabblity
of the sample is from real data.
"""
param_attr = ParamAttr( param_attr = ParamAttr(
is_static=is_generator_training, initial_mean=0.0, initial_std=0.02) is_static=is_generator_training, initial_mean=0.0, initial_std=0.02)
bias_attr = ParamAttr( bias_attr = ParamAttr(
is_static=is_generator_training, initial_mean=0.0, initial_std=0.0) is_static=is_generator_training, initial_mean=0.0, initial_std=0.0)
param_attr_bn = ParamAttr( param_attr_bn = ParamAttr(
is_static=is_generator_training, initial_mean=1.0, initial_std=0.02) is_static=is_generator_training, initial_mean=1.0, initial_std=0.02)
...@@ -473,66 +458,16 @@ def discriminator(sample): ...@@ -473,66 +458,16 @@ def discriminator(sample):
act=SoftmaxActivation()) act=SoftmaxActivation())
``` ```
在文件`gan_conf.py`当中我们定义了三个网络, **generator_training**, **discriminator_training** and **generator**. 和前文提到的模型结构的关系是:**discriminator_training** 是分类器,**generator** 是生成器,**generator_training** 是生成器加分类器因为训练生成器时需要用到分类器提供目标函数。这个对应关系在下面这段代码中定义:
```python
if is_generator_training:
noise = data_layer(name="noise", size=noise_dim)
# 函数generator定义了生成器的结构
sample = generator(noise)
if is_discriminator_training:
sample = data_layer(name="sample", size=sample_dim)
if is_generator_training or is_discriminator_training:
label = data_layer(name="label", size=1)
函数discriminator定义了判别器的结构
prob = discriminator(sample)
cost = cross_entropy(input=prob, label=label)
classification_error_evaluator(
input=prob, label=label, name=mode + '_error')
outputs(cost)
if is_generator:
noise = data_layer(name="noise", size=noise_dim)
outputs(generator(noise))
```
## 训练模型 ## 训练模型
### 用Paddle API解析模型设置并创建trainer
为了能够训练在上面的模型配置文件中定义的网络,我们首先需要用Paddle API完成如下几个步骤:
用MNIST手写数字图片训练对抗式生成网络可以用如下的命令: 1. 初始化Paddle环境
2. 解析设置
```bash 3. 由设置创造GradientMachine以及由GradientMachine创造trainer
$python gan_trainer.py -d mnist --use_gpu 1
```
训练中打印的日志信息大致如下:
```
d_pos_loss is 0.681067 d_neg_loss is 0.704936
d_loss is 0.693001151085 g_loss is 0.681496
...........d_pos_loss is 0.64475 d_neg_loss is 0.667874
d_loss is 0.656311988831 g_loss is 0.719081
...
I0105 17:15:48.346783 20517 TrainerInternal.cpp:165] Batch=100 samples=12800 AvgCost=0.701575 CurrentCost=0.701575 Eval: generator_training_error=0.679219 CurrentEval: generator_training_error=0.679219
.........d_pos_loss is 0.644203 d_neg_loss is 0.71601
d_loss is 0.680106401443 g_loss is 0.671118
....
I0105 17:16:37.172737 20517 TrainerInternal.cpp:165] Batch=100 samples=12800 AvgCost=0.687359 CurrentCost=0.687359 Eval: discriminator_training_error=0.438359 CurrentEval: discriminator_training_error=0.438359
```
其中`d_pos_loss`是判别器对于真实数据判别真的负对数概率,`d_neg_loss`是判别器对于伪数据判别为假的负对数概率,`d_loss`是这两者的平均值。`g_loss`是伪数据被判别器判别为真的负对数概率。对于对抗式生成网络来说,最好的训练情况是D和G的能力比较相近,也就是`d_loss``g_loss`在训练的前几个pass中数值比较接近(-log(0.5) = 0.693)。由于G和D是轮流训练,所以它们各自每过100个batch,都会打印各自的训练信息。
为了能够训练在gan_conf.py中定义的网络,我们需要如下几个步骤:
1. 初始化Paddle环境,
2. 解析设置,
3. 由设置创造GradientMachine以及由GradientMachine创造trainer。
这几步分别由下面几段代码实现: 这几步分别由下面几段代码实现:
```python ```python
import py_paddle.swig_paddle as api import py_paddle.swig_paddle as api
# 初始化Paddle环境 # 初始化Paddle环境
...@@ -540,12 +475,13 @@ api.initPaddle('--use_gpu=' + use_gpu, '--dot_period=10', ...@@ -540,12 +475,13 @@ api.initPaddle('--use_gpu=' + use_gpu, '--dot_period=10',
'--log_period=100', '--gpu_id=' + args.gpu_id, '--log_period=100', '--gpu_id=' + args.gpu_id,
'--save_dir=' + "./%s_params/" % data_source) '--save_dir=' + "./%s_params/" % data_source)
# 解析设置 # 解析设置:像上个小节提到的那样,gan的模型训练需要三个神经元网络
# 这里分别解析出三种模型设置
gen_conf = parse_config(conf, "mode=generator_training,data=" + data_source) gen_conf = parse_config(conf, "mode=generator_training,data=" + data_source)
dis_conf = parse_config(conf, "mode=discriminator_training,data=" + data_source) dis_conf = parse_config(conf, "mode=discriminator_training,data=" + data_source)
generator_conf = parse_config(conf, "mode=generator,data=" + data_source) generator_conf = parse_config(conf, "mode=generator,data=" + data_source)
# 由设置创造GradientMachine # 由模型设置创造GradientMachine
dis_training_machine = api.GradientMachine.createFromConfigProto( dis_training_machine = api.GradientMachine.createFromConfigProto(
dis_conf.model_config) dis_conf.model_config)
gen_training_machine = api.GradientMachine.createFromConfigProto( gen_training_machine = api.GradientMachine.createFromConfigProto(
...@@ -558,7 +494,9 @@ dis_trainer = api.Trainer.create(dis_conf, dis_training_machine) ...@@ -558,7 +494,9 @@ dis_trainer = api.Trainer.create(dis_conf, dis_training_machine)
gen_trainer = api.Trainer.create(gen_conf, gen_training_machine) gen_trainer = api.Trainer.create(gen_conf, gen_training_machine)
``` ```
为了能够平衡生成器和分类器之间的能力,我们依据它们各自的损失函数的大小来决定训练对象,即我们选择训练那个损失函数更大的网络。损失函数的值可以通过调用`GradientMachine``forward`方法来计算。
为了能够平衡生成器和判别器之间的能力,我们依据它们各自的损失函数的大小来决定训练对象,即我们选择训练那个损失函数更大的网络。损失函数的值可以通过调用`GradientMachine``forward`方法来计算。
```python ```python
def get_training_loss(training_machine, inputs): def get_training_loss(training_machine, inputs):
...@@ -580,6 +518,29 @@ dis_training_machine) ...@@ -580,6 +518,29 @@ dis_training_machine)
copy_shared_parameters(gen_training_machine, generator_machine) copy_shared_parameters(gen_training_machine, generator_machine)
``` ```
用MNIST手写数字图片训练对抗式生成网络可以用如下的命令。如果想用其他训练数据可以将参数`-d`改为uniform或者cifar。
```bash
$python gan_trainer.py -d mnist --use_gpu 1
```
训练中打印的日志信息大致如下:
```
d_pos_loss is 0.681067 d_neg_loss is 0.704936
d_loss is 0.693001151085 g_loss is 0.681496
...........d_pos_loss is 0.64475 d_neg_loss is 0.667874
d_loss is 0.656311988831 g_loss is 0.719081
...
I0105 17:15:48.346783 20517 TrainerInternal.cpp:165] Batch=100 samples=12800 AvgCost=0.701575 CurrentCost=0.701575 Eval: generator_training_error=0.679219 CurrentEval: generator_training_error=0.679219
.........d_pos_loss is 0.644203 d_neg_loss is 0.71601
d_loss is 0.680106401443 g_loss is 0.671118
....
I0105 17:16:37.172737 20517 TrainerInternal.cpp:165] Batch=100 samples=12800 AvgCost=0.687359 CurrentCost=0.687359 Eval: discriminator_training_error=0.438359 CurrentEval: discriminator_training_error=0.438359
```
其中`d_pos_loss`是判别器对于真实数据判别真的负对数概率,`d_neg_loss`是判别器对于伪数据判别为假的负对数概率,`d_loss`是这两者的平均值。`g_loss`是伪数据被判别器判别为真的负对数概率。对于对抗式生成网络来说,最好的训练情况是D和G的能力比较相近,也就是`d_loss``g_loss`在训练的前几个pass中数值比较接近(-log(0.5) = 0.693)。由于G和D是轮流训练,所以它们各自每过100个batch,都会打印各自的训练信息。
## 应用模型 ## 应用模型
图片由训练好的生成器生成。以下的代码将随机向量输入到模型 G,通过向前传递得到生成的图片。 图片由训练好的生成器生成。以下的代码将随机向量输入到模型 G,通过向前传递得到生成的图片。
...@@ -614,3 +575,6 @@ fake_samples = get_fake_samples(generator_machine, batch_size, noise) ...@@ -614,3 +575,6 @@ fake_samples = get_fake_samples(generator_machine, batch_size, noise)
3. Kingma D.P. and Welling M. [Auto-encoding variational bayes](https://arxiv.org/pdf/1312.6114v10.pdf)[C] arXiv preprint arXiv:1312.6114. 2013 3. Kingma D.P. and Welling M. [Auto-encoding variational bayes](https://arxiv.org/pdf/1312.6114v10.pdf)[C] arXiv preprint arXiv:1312.6114. 2013
4. Salakhutdinov R and Hinton G. [Deep Boltzmann Machines](http://www.jmlr.org/proceedings/papers/v5/salakhutdinov09a/salakhutdinov09a.pdf)[J] AISTATS. Vol. 1. 2009 4. Salakhutdinov R and Hinton G. [Deep Boltzmann Machines](http://www.jmlr.org/proceedings/papers/v5/salakhutdinov09a/salakhutdinov09a.pdf)[J] AISTATS. Vol. 1. 2009
5. Radford A, Metz L, Chintala S. [Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks](https://arxiv.org/pdf/1511.06434v2.pdf)[C] arXiv preprint arXiv:1511.06434. 2015 5. Radford A, Metz L, Chintala S. [Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks](https://arxiv.org/pdf/1511.06434v2.pdf)[C] arXiv preprint arXiv:1511.06434. 2015
6. Goodfellow I. [NIPS 2016 Tutorial: Generative Adversarial Networks] (https://arxiv.org/pdf/1701.00160v1.pdf) [C] arXiv preprint arXiv:1701.00160. 2016
7. Dumoulin V. and Visin F. [A guide to convolution arithmetic for deep learning] (https://arxiv.org/pdf/1603.07285v1.pdf). arXiv preprint arXiv:1603.07285. 2016
8. Kingma D., Ba J. [Adam: A method for stochastic optimization] (https://arxiv.org/pdf/1412.6980v8.pdf) arXiv preprint arXiv:1412.6980. 2014
...@@ -92,13 +92,17 @@ def conv_bn(input, ...@@ -92,13 +92,17 @@ def conv_bn(input,
nameApx = "_conv" nameApx = "_conv"
if bn: if bn:
conv_act = LinearActivation()
else:
conv_act = act
conv = img_conv_layer( conv = img_conv_layer(
input, input,
filter_size=filter_size, filter_size=filter_size,
num_filters=num_filters, num_filters=num_filters,
name=name + nameApx, name=name + nameApx,
num_channels=channels, num_channels=channels,
act=LinearActivation(), act=conv_act,
groups=1, groups=1,
stride=stride, stride=stride,
padding=padding, padding=padding,
...@@ -111,6 +115,7 @@ def conv_bn(input, ...@@ -111,6 +115,7 @@ def conv_bn(input,
padding_y=None, padding_y=None,
trans=trans) trans=trans)
if bn:
conv_bn = batch_norm_layer( conv_bn = batch_norm_layer(
conv, conv,
act=act, act=act,
...@@ -118,27 +123,8 @@ def conv_bn(input, ...@@ -118,27 +123,8 @@ def conv_bn(input,
bias_attr=bias_attr, bias_attr=bias_attr,
param_attr=param_attr_bn, param_attr=param_attr_bn,
use_global_stats=False) use_global_stats=False)
return conv_bn return conv_bn
else: else:
conv = img_conv_layer(
input,
filter_size=filter_size,
num_filters=num_filters,
name=name + nameApx,
num_channels=channels,
act=act,
groups=1,
stride=stride,
padding=padding,
bias_attr=bias_attr,
param_attr=param_attr,
shared_biases=True,
layer_attr=None,
filter_size_y=None,
stride_y=None,
padding_y=None,
trans=trans)
return conv return conv
......
...@@ -25,24 +25,6 @@ import py_paddle.swig_paddle as api ...@@ -25,24 +25,6 @@ import py_paddle.swig_paddle as api
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
def plot2DScatter(data, outputfile):
'''
Plot the data as a 2D scatter plot and save to outputfile
data needs to be two dimensinoal
'''
x = data[:, 0]
y = data[:, 1]
logger.info("The mean vector is %s" % numpy.mean(data, 0))
logger.info("The std vector is %s" % numpy.std(data, 0))
heatmap, xedges, yedges = numpy.histogram2d(x, y, bins=50)
extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]]
plt.clf()
plt.scatter(x, y)
plt.savefig(outputfile, bbox_inches='tight')
def CHECK_EQ(a, b): def CHECK_EQ(a, b):
assert a == b, "a=%s, b=%s" % (a, b) assert a == b, "a=%s, b=%s" % (a, b)
...@@ -80,6 +62,12 @@ def print_parameters(src): ...@@ -80,6 +62,12 @@ def print_parameters(src):
) )
# synthesize 2-D uniform data
def load_uniform_data():
data = numpy.random.rand(1000000, 2).astype('float32')
return data
def load_mnist_data(imageFile): def load_mnist_data(imageFile):
f = open(imageFile, "rb") f = open(imageFile, "rb")
f.read(16) f.read(16)
...@@ -111,10 +99,22 @@ def load_cifar_data(cifar_path): ...@@ -111,10 +99,22 @@ def load_cifar_data(cifar_path):
return data return data
# synthesize 2-D uniform data def plot2DScatter(data, outputfile):
def load_uniform_data(): '''
data = numpy.random.rand(1000000, 2).astype('float32') Plot the data as a 2D scatter plot and save to outputfile
return data data needs to be two dimensinoal
'''
x = data[:, 0]
y = data[:, 1]
logger.info("The mean vector is %s" % numpy.mean(data, 0))
logger.info("The std vector is %s" % numpy.std(data, 0))
heatmap, xedges, yedges = numpy.histogram2d(x, y, bins=50)
extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]]
plt.clf()
plt.scatter(x, y)
plt.savefig(outputfile, bbox_inches='tight')
def merge(images, size): def merge(images, size):
...@@ -140,6 +140,13 @@ def save_images(images, path): ...@@ -140,6 +140,13 @@ def save_images(images, path):
im.save(path) im.save(path)
def save_results(samples, path, data_source):
if data_source == "uniform":
plot2DScatter(samples, path)
else:
save_images(samples, path)
def get_real_samples(batch_size, data_np): def get_real_samples(batch_size, data_np):
return data_np[numpy.random.choice( return data_np[numpy.random.choice(
data_np.shape[0], batch_size, replace=False), :] data_np.shape[0], batch_size, replace=False), :]
...@@ -210,9 +217,14 @@ def main(): ...@@ -210,9 +217,14 @@ def main():
parser.add_argument( parser.add_argument(
"--use_gpu", default="1", help="1 means use gpu for training") "--use_gpu", default="1", help="1 means use gpu for training")
parser.add_argument("--gpu_id", default="0", help="the gpu_id parameter") parser.add_argument("--gpu_id", default="0", help="the gpu_id parameter")
parser.add_argument(
"--model_dir",
default="",
help="model path for generating samples, empty means training mode")
args = parser.parse_args() args = parser.parse_args()
data_source = args.data_source data_source = args.data_source
use_gpu = args.use_gpu use_gpu = args.use_gpu
model_dir = args.model_dir
assert data_source in ["mnist", "cifar", "uniform"] assert data_source in ["mnist", "cifar", "uniform"]
assert use_gpu in ["0", "1"] assert use_gpu in ["0", "1"]
...@@ -237,6 +249,8 @@ def main(): ...@@ -237,6 +249,8 @@ def main():
dis_conf = parse_config(conf, dis_conf = parse_config(conf,
"mode=discriminator_training,data=" + data_source) "mode=discriminator_training,data=" + data_source)
generator_conf = parse_config(conf, "mode=generator,data=" + data_source) generator_conf = parse_config(conf, "mode=generator,data=" + data_source)
logger.info(str(generator_conf.model_config))
batch_size = dis_conf.opt_config.batch_size batch_size = dis_conf.opt_config.batch_size
noise_dim = get_layer_size(gen_conf.model_config, "noise") noise_dim = get_layer_size(gen_conf.model_config, "noise")
...@@ -253,15 +267,21 @@ def main(): ...@@ -253,15 +267,21 @@ def main():
# this create a gradient machine for generator # this create a gradient machine for generator
gen_training_machine = api.GradientMachine.createFromConfigProto( gen_training_machine = api.GradientMachine.createFromConfigProto(
gen_conf.model_config) gen_conf.model_config)
# generator_machine is used to generate data only, which is used for # generator_machine is used to generate data only, which is used for
# training discriminator # training discriminator
logger.info(str(generator_conf.model_config))
generator_machine = api.GradientMachine.createFromConfigProto( generator_machine = api.GradientMachine.createFromConfigProto(
generator_conf.model_config) generator_conf.model_config)
dis_trainer = api.Trainer.create(dis_conf, dis_training_machine) # In the generating settings, use previously trained model to generate
# fake samples
if model_dir != "":
generator_machine.loadParameters(model_dir)
noise = get_noise(batch_size, noise_dim)
fake_samples = get_fake_samples(generator_machine, batch_size, noise)
save_results(fake_samples, "./generated_samples.png", data_source)
return
dis_trainer = api.Trainer.create(dis_conf, dis_training_machine)
gen_trainer = api.Trainer.create(gen_conf, gen_training_machine) gen_trainer = api.Trainer.create(gen_conf, gen_training_machine)
dis_trainer.startTrain() dis_trainer.startTrain()
...@@ -325,8 +345,6 @@ def main(): ...@@ -325,8 +345,6 @@ def main():
curr_train = "gen" curr_train = "gen"
curr_strike = 1 curr_strike = 1
gen_trainer.trainOneDataBatch(batch_size, data_batch_gen) gen_trainer.trainOneDataBatch(batch_size, data_batch_gen)
# TODO: add API for paddle to allow true parameter sharing between different GradientMachines
# so that we do not need to copy shared parameters.
copy_shared_parameters(gen_training_machine, copy_shared_parameters(gen_training_machine,
dis_training_machine) dis_training_machine)
copy_shared_parameters(gen_training_machine, generator_machine) copy_shared_parameters(gen_training_machine, generator_machine)
...@@ -335,12 +353,8 @@ def main(): ...@@ -335,12 +353,8 @@ def main():
gen_trainer.finishTrainPass() gen_trainer.finishTrainPass()
# At the end of each pass, save the generated samples/images # At the end of each pass, save the generated samples/images
fake_samples = get_fake_samples(generator_machine, batch_size, noise) fake_samples = get_fake_samples(generator_machine, batch_size, noise)
if data_source == "uniform": save_results(fake_samples, "./%s_samples/train_pass%s.png" %
plot2DScatter(fake_samples, "./%s_samples/train_pass%s.png" % (data_source, train_pass), data_source)
(data_source, train_pass))
else:
save_images(fake_samples, "./%s_samples/train_pass%s.png" %
(data_source, train_pass))
dis_trainer.finishTrain() dis_trainer.finishTrain()
gen_trainer.finishTrain() gen_trainer.finishTrain()
......
文件已添加
gan/image/gan_conf_graph.png

79.2 KB | W: | H:

gan/image/gan_conf_graph.png

125.7 KB | W: | H:

gan/image/gan_conf_graph.png
gan/image/gan_conf_graph.png
gan/image/gan_conf_graph.png
gan/image/gan_conf_graph.png
  • 2-up
  • Swipe
  • Onion skin
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册