05.md 51.1 KB
Newer Older
W
wizardforcel 已提交
1 2 3 4
# 5

# 改进的 GAN

W
wizardforcel 已提交
5
自 2014 年引入**生成对抗网络****GAN**)以来,其流行度迅速提高。 GAN 已被证明是有用的生成模型,可以合成看起来真实的新数据。 深度学习中的许多研究论文都遵循提出的措施来解决原始 GAN 的困难和局限性。
W
wizardforcel 已提交
6

W
wizardforcel 已提交
7
正如我们在前几章中讨论的那样,众所周知,GAN 很难训练,并且易于崩溃。 模式损耗是一种情况,即使损失函数已经被优化,但发生器仍会产生看起来相同的输出。 在 MNIST 数字的情况下,模式折叠时,生成器可能只产生数字 4 和 9,因为看起来很相似。 **Wasserstein GAN****WGAN**)[2]解决了这些问题,认为只需替换基于 **Wasserstein** 的 GAN 损失函数就可以稳定的训练和避免模式崩溃,也称为**陆地移动距离****EMD**)。
W
wizardforcel 已提交
8

W
wizardforcel 已提交
9
但是,稳定性问题并不是 GAN 的唯一问题。 也越来越需要来提高所生成图像的感知质量。 **最小二乘 GAN****LSGAN**)[3]建议同时解决这两个问题。 基本前提是,在训练过程中,Sigmoid 交叉熵损失会导致梯度消失。 这导致较差的图像质量。 最小二乘损失不会导致梯度消失。 与香草 GAN 生成的图像相比,生成的生成图像具有更高的感知质量。
W
wizardforcel 已提交
10

W
wizardforcel 已提交
11
在上一章中,CGAN 介绍了一种调节发电机输出的方法。 例如,如果要获取数字 8,则可以在生成器的输入中包含条件标签。 受 CGAN 的启发,**辅助分类器 GAN****ACGAN**)[4]提出了一种改进的条件算法,可产生更好的感知质量和输出多样性。
W
wizardforcel 已提交
12 13 14 15 16 17 18 19 20 21 22 23

总之,本章的目的是介绍:

*   WGAN 的理论表述
*   对 LSGAN 原理的理解
*   对 ACGAN 原理的理解
*   改进的 GAN 的`tf.keras`实现– WGAN,LSGAN 和 ACGAN

让我们从讨论 WGAN 开始。

# 1\. Wasserstein GAN

W
wizardforcel 已提交
24
如前所述,众所周知,GAN 很难训练。 判别器和生成器这两个网络的相反目标很容易导致训练不稳定。 判别器尝试从真实数据中正确分类伪造数据。 同时,生成器将尽最大努力欺骗判别器。 如果判别器的学习速度比生成器快,则生成器参数将无法优化。 另一方面,如果判别器学习较慢,则梯度可能会在到达生成器之前消失。 在最坏的情况下,如果判别器无法收敛,则生成器将无法获得任何有用的反馈。
W
wizardforcel 已提交
25

W
wizardforcel 已提交
26
WGAN 认为 GAN 固有的不稳定性是由于它的损失函数引起的,该函数基于 **Jensen-Shannon****JS**)距离。 在 GAN 中,生成器的目的是学习如何将一种源分布(例如噪声)从转换为估计的目标分布(例如 MNIST 数字)。 使用 GAN 的原始公式,损失函数实际上是使目标分布与其估计值之间的距离最小。 问题是,对于某些分布对,没有平滑的路径可以最小化此 JS 距离。 因此,训练将无法收敛。
W
wizardforcel 已提交
27 28 29 30 31 32 33

在以下部分中,我们将研究三个距离函数,并分析什么可以替代更适合 GAN 优化的 JS 距离函数。

## 距离功能

可以通过检查其损失函数来了解训练 GAN 的稳定性。 为了更好地理解 GAN 损失函数,我们将回顾两个概率分布之间的公共距离或散度函数。

W
wizardforcel 已提交
34
我们关注的是用于真实数据分配的`p_data`与用于发电机数据分配的`p_g`之间的距离。 GAN 的目标是制造![](img/B14853_05_001.png)。“表 5.1.1”显示了散度函数。
W
wizardforcel 已提交
35

W
wizardforcel 已提交
36
在大多数个最大似然任务中,我们将使用 **Kullback-Leibler****KL**)散度,或`D[KL]`损失函数可以衡量我们的神经网络模型预测与真实分布函数之间的距离。 如“公式 5.1.1”所示,由于![](img/B14853_05_002.png),所以`D[KL]`不对称。
W
wizardforcel 已提交
37

W
wizardforcel 已提交
38
**JS**`D[JS]`是基于`D[KL]`的差异。 但是,与`D[KL]`不同,`D[JS]`是对称的并且是有限的。 在本节中,我们将演示优化 GAN 损失函数等同于优化`D[JS]`
W
wizardforcel 已提交
39 40

<colgroup><col> <col></colgroup> 
W
wizardforcel 已提交
41
| **散度** | **表达式** |
W
wizardforcel 已提交
42
| Kullback-Leibler(KL)“公式 5.1.1” | 
W
wizardforcel 已提交
43

W
wizardforcel 已提交
44
![](img/B14853_05_003.png)
W
wizardforcel 已提交
45

W
wizardforcel 已提交
46
![](img/B14853_05_004.png)
W
wizardforcel 已提交
47 48

 |
W
wizardforcel 已提交
49
| *詹森·香农(JS)“公式 5.1.2” | 
W
wizardforcel 已提交
50

W
wizardforcel 已提交
51
![](img/B14853_05_005.png)
W
wizardforcel 已提交
52 53

 |
W
wizardforcel 已提交
54
| 陆地移动距离(EMD)或 Wasserstein 1 “公式 5.1.3” | 
W
wizardforcel 已提交
55

W
wizardforcel 已提交
56
![](img/B14853_05_006.png)
W
wizardforcel 已提交
57

W
wizardforcel 已提交
58
其中![](img/B14853_05_007.png)是所有联合分布![](img/B14853_05_008.png)的集合,其边际为`p_data``p_g`
W
wizardforcel 已提交
59

W
wizardforcel 已提交
60
表 5.1.1:两个概率分布函数`p_data``p_g`之间的散度函数
W
wizardforcel 已提交
61

W
wizardforcel 已提交
62
EMD 背后的想法是,它是为了确定概率分布`p_data`应由![](img/B14853_05_010.png)传输多少质量![](img/B14853_05_009.png)`p_g`。 ![](img/B14853_05_011.png)是所有可能的关节分布[H​​TG11]中的关节分布。 ![](img/B14853_05_013.png)也被称为运输计划,以反映运输质量以匹配两个概率分布的策略。 给定两个概率分布,有许多可能的运输计划。 大致而言, `inf`表示成本最低的运输计划。
W
wizardforcel 已提交
63

W
wizardforcel 已提交
64
例如,“图 5.1.1”向我们展示了两个简单的离散分布`x``y`
W
wizardforcel 已提交
65

W
wizardforcel 已提交
66
![](img/B14853_05_01.png)
W
wizardforcel 已提交
67

W
wizardforcel 已提交
68
图 5.1.1:EMD 是从`x`传输以匹配目标分布`y`的质量的加权数量。
W
wizardforcel 已提交
69

W
wizardforcel 已提交
70
在位置`i = 1, 2, 3, 4`上,`x`在具有质量`m[i], i = 1, 2, 3, 4`。同时,位置`y[i], i = 1, 2`上,`y`的质量为`m[i], i = 1, 2`。为了匹配分布`y`,图中的箭头显示了将每个质量`x[i]`移动`d[i]`的最小运输计划。 EMD 计算如下:
W
wizardforcel 已提交
71

W
wizardforcel 已提交
72
![](img/B14853_05_014.png) (Equation 5.1.4)
W
wizardforcel 已提交
73

W
wizardforcel 已提交
74
在“图 5.1.1”中,EMD 可解释为移动一堆污物`x`填充孔`y`所需的最少工作量。 尽管在此示例中,也可以从图中推导出`inf`,但在大多数情况下,尤其是在连续分布中,用尽所有可能的运输计划是很棘手的。 我们将在本章中稍后回到这个问题。 同时,我们将向您展示 GAN 损失函数的作用,实际上是如何使 **JS** 的差异最小化。
W
wizardforcel 已提交
75 76 77

## GAN 中的距离函数

W
wizardforcel 已提交
78
现在,在上一章的损失函数给定任何生成器的情况下,我们将计算最佳判别器。 我们将回顾上一章中的以下等式:
W
wizardforcel 已提交
79

W
wizardforcel 已提交
80
![](img/B14853_05_015.png) (Equation 4.1.1)
W
wizardforcel 已提交
81 82 83

除了从噪声分布中采样外,前面的等式也可以表示为从发生器分布中采样:

W
wizardforcel 已提交
84
![](img/B14853_05_016.png) (Equation 5.1.5)
W
wizardforcel 已提交
85 86 87

找出最小的![](img/B14853_05_017.png)

W
wizardforcel 已提交
88
![](img/B14853_05_018.png) (Equation 5.1.6)
W
wizardforcel 已提交
89

W
wizardforcel 已提交
90
![](img/B14853_05_019.png) (Equation 5.1.7)
W
wizardforcel 已提交
91 92 93

积分内部的术语为![](img/B14853_05_020.png)的形式,对于不包括![](img/B14853_05_024.png)的任何![](img/B14853_05_023.png),在![](img/B14853_05_022.png)的![](img/B14853_05_021.png)处都有一个已知的最大值。 由于该积分不会更改此表达式的最大值(或![](img/B14853_05_025.png)的最小值)的位置,因此最佳判别器为:

W
wizardforcel 已提交
94
![](img/B14853_05_026.png) (Equation 5.1.8)
W
wizardforcel 已提交
95 96 97

因此,给定最佳判别器的损失函数为:

W
wizardforcel 已提交
98
![](img/B14853_05_027.png) (Equation 5.1.9)
W
wizardforcel 已提交
99

W
wizardforcel 已提交
100
![](img/B14853_05_028.png) (Equation 5.1.10)
W
wizardforcel 已提交
101

W
wizardforcel 已提交
102
![](img/B14853_05_029.png) (Equation 5.1.11)
W
wizardforcel 已提交
103

W
wizardforcel 已提交
104
![](img/B14853_05_030.png) (Equation 5.1.12)
W
wizardforcel 已提交
105

W
wizardforcel 已提交
106
我们可以从“公式 5.1.12”观察到,最佳判别器的损失函数为常数减去真实分布![](img/B14853_05_031.png)和任何发生器分布`p_g`之间的 JS 散度的两倍。 最小化![](img/B14853_05_032.png)意味着最大化![](img/B14853_05_033.png),否则鉴别者必须正确地将真实数据中的伪造物分类。
W
wizardforcel 已提交
107 108 109

同时,我们可以放心地说,最佳生成器是当生成器分布等于真实数据分布时:

W
wizardforcel 已提交
110
![](img/B14853_05_034.png) (Equation 5.1.13)
W
wizardforcel 已提交
111

W
wizardforcel 已提交
112
这是有道理的,因为生成器的目的是通过学习真实的数据分布来欺骗判别器。 有效地,我们可以通过最小化`D[JS]`或通过制作![](img/B14853_05_035.png)来获得最佳生成器。 给定最佳生成器,最佳判别器为![](img/B14853_05_036.png)和![](img/B14853_05_037.png)
W
wizardforcel 已提交
113 114 115

问题在于,当两个分布没有重叠时,就没有平滑函数可以帮助缩小它们之间的差距。 训练 GAN 不会因梯度下降而收敛。 例如,假设:

W
wizardforcel 已提交
116
![](img/B14853_05_038.png)=![](img/B14853_05_039.png) where ![](img/B14853_05_040.png) (Equation 5.1.14)
W
wizardforcel 已提交
117

W
wizardforcel 已提交
118
![](img/B14853_05_041.png)=![](img/B14853_05_042.png) where ![](img/B14853_05_043.png) (Equation 5.1.15)
W
wizardforcel 已提交
119

W
wizardforcel 已提交
120
这两个分布显示在“图 5.1.2”中:
W
wizardforcel 已提交
121

W
wizardforcel 已提交
122
![](img/B14853_05_02.png)
W
wizardforcel 已提交
123

W
wizardforcel 已提交
124
图 5.1.2:没有重叠的两个分布的示例。 对于`p_g`,θ= 0.5
W
wizardforcel 已提交
125 126 127 128 129 130 131 132

![](img/B14853_05_044.png)是均匀分布。 每个距离函数的差异如下:

*   ![](img/B14853_05_045.png)
*   ![](img/B14853_05_046.png)
*   ![](img/B14853_05_047.png)
*   ![](img/B14853_05_048.png)

W
wizardforcel 已提交
133
由于`D[JS]`是一个常数,因此 GAN 将没有足够的梯度来驱动![](img/B14853_05_049.png)。 我们还会发现`D[KL]`或反向`D[KL]`也不起作用。 但是,通过`W(p_data, p_g)`,我们可以使平滑 通过梯度下降获得![](img/B14853_05_050.png)。 为了优化 GAN,EMD 或 Wasserstein 1 似乎是一个更具逻辑性的损失函数,因为在两个分布具有极小或没有重叠的情况下,`D[JS]`会失败。
W
wizardforcel 已提交
134

W
wizardforcel 已提交
135
为了帮助进一步理解,可以在以下位置找到[有关距离函数的精彩讨论](https://lilianweng.github.io/lil-log/2017/08/20/from-GAN-to-WGAN.html)
W
wizardforcel 已提交
136 137 138 139 140 141 142

在下一节中,我们将重点介绍使用 EMD 或 Wasserstein 1 距离函数来开发替代损失函数,以鼓励稳定训练 GAN。

## 使用 Wasserstein 损失

在使用 EMD 或 Wasserstein 1 之前,还有一个要解决的问题。 耗尽![](img/B14853_05_051.png)的空间来找到![](img/B14853_05_052.png)是很棘手的。 提出的解决方案是使用其 Kantorovich-Rubinstein 对偶:

W
wizardforcel 已提交
143
![](img/B14853_05_053.png) (Equation 5.1.16)
W
wizardforcel 已提交
144

W
wizardforcel 已提交
145
等效地,EMD ![](img/B14853_05_054.png)是所有 K-Lipschitz 函数上的最高值(大约是最大值):![](img/B14853_05_055.png)。 K-Lipschitz 函数满足以下约束:
W
wizardforcel 已提交
146

W
wizardforcel 已提交
147
![](img/B14853_05_056.png) (Equation 5.1.17)
W
wizardforcel 已提交
148

W
wizardforcel 已提交
149
对于所有![](img/B14853_05_057.png)。 K-Lipschitz 函数具有有界导数,并且几乎总是连续可微的(例如,![](img/B14853_05_058.png)具有有界导数并且是连续的,但在`x = 0`时不可微分)。
W
wizardforcel 已提交
150

W
wizardforcel 已提交
151
“公式 5.1.16”可以通过找到 K-Lipschitz 函数![](img/B14853_05_059.png)的族来求解:
W
wizardforcel 已提交
152

W
wizardforcel 已提交
153
![](img/B14853_05_060.png) (Equation 5.1.18)
W
wizardforcel 已提交
154

W
wizardforcel 已提交
155
在 GAN 中,可以通过从`z`-噪声分布采样并用`f[w]`替换“公式 5.1.18”来重写。 鉴别函数,`D[w]`
W
wizardforcel 已提交
156

W
wizardforcel 已提交
157
![](img/B14853_05_061.png) (Equation 5.1.19)
W
wizardforcel 已提交
158

W
wizardforcel 已提交
159
我们使用粗体字母突出显示多维样本的一般性。 最后一个问题是如何找到功能族![](img/B14853_05_062.png)。 所提出的解决方案是在每次梯度更新时进行的。 判别器`w`的权重被限制在上下限之间(例如,-0.01 和 0.01):
W
wizardforcel 已提交
160

W
wizardforcel 已提交
161
![](img/B14853_05_063.png) (Equation 5.1.20)
W
wizardforcel 已提交
162

W
wizardforcel 已提交
163
`w`的较小值将判别器约束到紧凑的参数空间,从而确保 Lipschitz 连续性。
W
wizardforcel 已提交
164

W
wizardforcel 已提交
165
我们可以使用“公式 5.1.19”作为我们新的 GAN 损失函数的基础。 EMD 或 Wasserstein 1 是生成器旨在最小化的损失函数,以及判别器试图最大化的损失函数(或最小化`-W(p_data, p_g)`
W
wizardforcel 已提交
166

W
wizardforcel 已提交
167
![](img/B14853_05_064.png) (Equation 5.1.21)
W
wizardforcel 已提交
168

W
wizardforcel 已提交
169
![](img/B14853_05_065.png) (Equation 5.1.22)
W
wizardforcel 已提交
170 171 172

在发电机损失函数中,第一项消失了,因为它没有针对实际数据进行直接优化。

W
wizardforcel 已提交
173
“表 5.1.2”显示了 GAN 和 WGAN 的损失函数之间的差异。 为简洁起见,我们简化了![](img/B14853_05_066.png)和![](img/B14853_05_067.png)的表示法:
W
wizardforcel 已提交
174 175 176 177 178

<colgroup><col> <col> <col></colgroup> 
| **网络** | **损失函数** | **公式** |
| GAN | 

W
wizardforcel 已提交
179
![](img/B14853_05_068.png)
W
wizardforcel 已提交
180

W
wizardforcel 已提交
181
![](img/B14853_05_069.png)
W
wizardforcel 已提交
182 183 184 185

 | 4.1.14.1.5 |
| WGAN | 

W
wizardforcel 已提交
186
![](img/B14853_05_070.png)
W
wizardforcel 已提交
187

W
wizardforcel 已提交
188
![](img/B14853_05_071.png)
W
wizardforcel 已提交
189

W
wizardforcel 已提交
190
![](img/B14853_05_072.png)
W
wizardforcel 已提交
191 192 193 194 195

 | 5.1.215.1.225.1.20 |

表 5.1.2:GAN 和 WGAN 的损失函数之间的比较

W
wizardforcel 已提交
196
这些损失功能用于训练 WGAN,如“算法 5.1.1”中所示。
W
wizardforcel 已提交
197

W
wizardforcel 已提交
198
**算法 5.1.1 WGAN**。 参数的值为![](img/B14853_05_073.png),![](img/B14853_05_074.png),![](img/B14853_05_075.png)和![](img/B14853_05_076.png)
W
wizardforcel 已提交
199

W
wizardforcel 已提交
200
要求:![](img/B14853_05_077.png),学习率。`c`是削波参数。`m`,批量大小。 ![](img/B14853_05_078.png),即每个生成器迭代的评论(鉴别)迭代次数。
W
wizardforcel 已提交
201 202 203 204 205 206 207

要求:![](img/B14853_05_079.png),初始注释器(discriminator)参数。 ![](img/B14853_05_080.png),初始发电机参数:

1.  `while` ![](img/B14853_05_081.png)尚未收敛`do`
2.  `for` ![](img/B14853_05_082.png) `do`
3.  从真实数据中抽样一批![](img/B14853_05_083.png)
4.  从均匀的噪声分布中采样一批![](img/B14853_05_084.png)
W
wizardforcel 已提交
208 209
5.  ![](img/B14853_05_085.png),计算判别器梯度
6.  ![](img/B14853_05_086.png),更新判别器参数
W
wizardforcel 已提交
210
7.  ![](img/B14853_05_087.png),剪辑判别器权重
W
wizardforcel 已提交
211 212 213 214 215 216
8.  `end for`
9.  从均匀的噪声分布中采样一批![](img/B14853_05_088.png)
10.  ![](img/B14853_05_089.png),计算生成器梯度
11.  ![](img/B14853_05_090.png),更新生成器参数
12.  `end while`

W
wizardforcel 已提交
217
“图 5.1.3”展示了 WGAN 模型实际上与 DCGAN 相同,除了伪造的/真实的数据标签和损失函数:
W
wizardforcel 已提交
218

W
wizardforcel 已提交
219
![](img/B14853_05_03.png)
W
wizardforcel 已提交
220

W
wizardforcel 已提交
221
图 5.1.3:顶部:训练 WGAN 判别器需要来自生成器的虚假数据和来自真实分发的真实数据。 下:训练 WGAN 生成器要求生成器中假冒的真实数据是真实的
W
wizardforcel 已提交
222

W
wizardforcel 已提交
223
与 GAN 相似,WGAN 交替训练判别器和生成器(通过对抗)。 但是,在 WGAN 中,判别器(也称为评论者)在训练生成器进行一次迭代(第 9 至 11 行)之前,先训练`n_critic`迭代(第 2 至 8 行)。 这与对于判别器和生成器具有相同数量的训练迭代的 GAN 相反。 换句话说,在 GAN 中,`n_critic = 1`
W
wizardforcel 已提交
224

W
wizardforcel 已提交
225
训练判别器意味着学习判别器的参数(权重和偏差)。 这需要从真实数据中采样一批(第 3 行),并从伪数据中采样一批(第 4 行),然后将采样数据馈送到判别器网络,然后计算判别器参数的梯度(第 5 行)。 判别器参数使用 RMSProp(第 6 行)进行了优化。 第 5 行和第 6 行都是“公式 5.1.21”的优化。
W
wizardforcel 已提交
226

W
wizardforcel 已提交
227
最后,EM 距离优化中的 Lipschitz 约束是通过裁剪判别器参数(第 7 行)来施加的。 第 7 行是“公式 5.1.20”的实现。 在`n_critic`迭代判别器训练之后,判别器参数被冻结。 生成器训练通过对一批伪造数据进行采样开始(第 9 行)。 采样的数据被标记为实数(1.0),以致愚弄判别器网络。 在第 10 行中计算生成器梯度,并在第 11 行中使用 RMSProp 对其进行优化。第 10 行和第 11 行执行梯度更新以优化“公式 5.1.22”。
W
wizardforcel 已提交
228

W
wizardforcel 已提交
229
训练生成器后,将解冻判别器参数,并开始另一个`n_critic`判别器训练迭代。 我们应该注意,在判别器训练期间不需要冻结生成器参数,因为生成器仅涉及数据的制造。 类似于 GAN,可以将判别器训练为一个单独的网络。 但是,训练发电机始终需要鉴别者通过对抗网络参与,因为损失是根据发电机网络的输出计算得出的。
W
wizardforcel 已提交
230

W
wizardforcel 已提交
231
与 GAN 不同,在 WGAN 中,将实际数据标记为 1.0,而将伪数据标记为 -1.0,作为计算第 5 行中的梯度的一种解决方法。第 5-6 和 10-11 行执行梯度更新以优化“公式 5.1.21”和“5.1.22”。 第 5 行和第 10 行中的每一项均建模为:
W
wizardforcel 已提交
232

W
wizardforcel 已提交
233
![](img/B14853_05_091.png) (Equation 5.1.23)
W
wizardforcel 已提交
234

W
wizardforcel 已提交
235
对于真实数据,其中`y_label = 1.0`,对于假数据,`y_label= -1.0`。 为了简化符号,我们删除了上标`(i)`。 对于判别器,当使用实际数据进行训练时,WGAN 增加![](img/B14853_05_092.png)以最小化损失函数。
W
wizardforcel 已提交
236

W
wizardforcel 已提交
237
使用伪造数据进行训练时,WGAN 会降低![](img/B14853_05_093.png)以最大程度地减少损失函数。 对于生成器,当在训练过程中将伪数据标记为真实数据时,WGAN 增加![](img/B14853_05_094.png)以最小化损失功能。 请注意,`y_label`除了其符号外,对损失功能没有直接贡献。 在`tf.keras`中,“公式 5.1.23”实现为:
W
wizardforcel 已提交
238 239 240 241 242 243

```py
def wasserstein_loss(y_label, y_pred):
    return -K.mean(y_label * y_pred) 
```

W
wizardforcel 已提交
244
本节最重要的部分是用于稳定训练 GAN 的新损失函数。 它基于 EMD 或 Wasserstein1。“算法 5.1.1”形式化了 WGAN 的完整训练算法,包括损失函数。 在下一节中,将介绍`tf.keras`中训练算法的实现。
W
wizardforcel 已提交
245

W
wizardforcel 已提交
246
## 使用 Keras 的 WGAN 实现
W
wizardforcel 已提交
247 248 249 250 251 252

为了在`tf.keras`中实现 WGAN,我们可以重用 GAN 的 DCGAN 实现,这是我们在上一一章中介绍的。 DCGAN 构建器和实用程序功能在`lib`文件夹的`gan.py`中作为模块实现。

功能包括:

*   `generator()`:生成器模型构建器
W
wizardforcel 已提交
253
*   `discriminator()`:判别器模型构建器
W
wizardforcel 已提交
254
*   `train()`:DCGAN 训练师
W
wizardforcel 已提交
255 256 257
*   `plot_images()`:通用发生器输出绘图仪
*   `test_generator()`:通用的发电机测试实用程序

W
wizardforcel 已提交
258
如“列表 5.1.1”所示,我们可以通过简单地调用以下命令来构建一个判别器:
W
wizardforcel 已提交
259 260 261 262 263 264 265 266 267 268 269

```py
discriminator = gan.discriminator(inputs, activation='linear') 
```

WGAN 使用线性输出激活。 对于生成器,我们执行:

```py
generator = gan.generator(inputs, image_size) 
```

W
wizardforcel 已提交
270
`tf.keras`中的整体网络模型类似于 DCGAN 的“图 4.2.1”中看到的模型。
W
wizardforcel 已提交
271

W
wizardforcel 已提交
272
“列表 5.1.1”突出显示了 RMSprop 优化器和 Wasserstein 损失函数的使用。 在训练期间使用“算法 5.1.1”中的超参数。
W
wizardforcel 已提交
273

W
wizardforcel 已提交
274
[完整的代码可在 GitHub 上获得](https://github.com/PacktPublishing/Advanced-Deep-Learning-with-Keras)
W
wizardforcel 已提交
275

W
wizardforcel 已提交
276
“列表 5.1.1”:`wgan-mnist-5.1.2.py`
W
wizardforcel 已提交
277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354

```py
def build_and_train_models():
    """Load the dataset, build WGAN discriminator,
    generator, and adversarial models.
    Call the WGAN train routine.
    """
    # load MNIST dataset
    (x_train, _), (_, _) = mnist.load_data() 
```

```py
 # reshape data for CNN as (28, 28, 1) and normalize
    image_size = x_train.shape[1]
    x_train = np.reshape(x_train, [-1, image_size, image_size, 1])
    x_train = x_train.astype('float32') / 255 
```

```py
 model_name = "wgan_mnist"
    # network parameters
    # the latent or z vector is 100-dim
    latent_size = 100
    # hyper parameters from WGAN paper [2]
    n_critic = 5
    clip_value = 0.01
    batch_size = 64
    lr = 5e-5
    train_steps = 40000
    input_shape = (image_size, image_size, 1) 
```

```py
 # build discriminator model
    inputs = Input(shape=input_shape, name='discriminator_input')
    # WGAN uses linear activation in paper [2]
    discriminator = gan.discriminator(inputs, activation='linear')
    optimizer = RMSprop(lr=lr)
    # WGAN discriminator uses wassertein loss
    discriminator.compile(loss=wasserstein_loss,
                          optimizer=optimizer,
                          metrics=['accuracy'])
    discriminator.summary() 
```

```py
 # build generator model
    input_shape = (latent_size, )
    inputs = Input(shape=input_shape, name='z_input')
    generator = gan.generator(inputs, image_size)
    generator.summary() 
```

```py
 # build adversarial model = generator + discriminator
    # freeze the weights of discriminator during adversarial training
    discriminator.trainable = False
    adversarial = Model(inputs,
                        discriminator(generator(inputs)),
                        name=model_name)
    adversarial.compile(loss=wasserstein_loss,
                        optimizer=optimizer,
                        metrics=['accuracy'])
    adversarial.summary() 
```

```py
 # train discriminator and adversarial networks
    models = (generator, discriminator, adversarial)
    params = (batch_size,
              latent_size,
              n_critic,
              clip_value,
              train_steps,
              model_name)
    train(models, x_train, params) 
```

W
wizardforcel 已提交
355
“列表 5.1.2”是紧跟“算法 5.1.1”的训练函数。 但是,在判别器的训练中有一个小的调整。 与其在单个合并的真实数据和虚假数据中组合训练权重,不如先训练一批真实数据,然后再训练一批虚假数据。 这种调整将防止梯度消失,因为真实和伪造数据标签中的符号相反,并且由于裁剪而导致的权重较小。
W
wizardforcel 已提交
356

W
wizardforcel 已提交
357
“列表 5.1.2”:`wgan-mnist-5.1.2.py`
W
wizardforcel 已提交
358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494

为 WGAN 训练算法:

```py
def train(models, x_train, params):
    """Train the Discriminator and Adversarial Networks 
```

```py
 Alternately train Discriminator and Adversarial
    networks by batch.
    Discriminator is trained first with properly labelled
    real and fake images for n_critic times.
    Discriminator weights are clipped as a requirement 
    of Lipschitz constraint.
    Generator is trained next (via Adversarial) with 
    fake images pretending to be real.
    Generate sample images per save_interval 
```

```py
 Arguments:
        models (list): Generator, Discriminator,
            Adversarial models
        x_train (tensor): Train images
        params (list) : Networks parameters 
```

```py
 """
    # the GAN models
    generator, discriminator, adversarial = models
    # network parameters
    (batch_size, latent_size, n_critic,
            clip_value, train_steps, model_name) = params
    # the generator image is saved every 500 steps
    save_interval = 500
    # noise vector to see how the 
    # generator output evolves during training
    noise_input = np.random.uniform(-1.0,
                                    1.0,
                                    size=[16, latent_size])
    # number of elements in train dataset
    train_size = x_train.shape[0]
    # labels for real data
    real_labels = np.ones((batch_size, 1))
    for i in range(train_steps):
        # train discriminator n_critic times
        loss = 0
        acc = 0
        for _ in range(n_critic):
            # train the discriminator for 1 batch
            # 1 batch of real (label=1.0) and 
            # fake images (label=-1.0)
            # randomly pick real images from dataset
            rand_indexes = np.random.randint(0,
                                             train_size,
                                             size=batch_size)
            real_images = x_train[rand_indexes]
            # generate fake images from noise using generator
            # generate noise using uniform distribution
            noise = np.random.uniform(-1.0,
                                      1.0,
                                      size=[batch_size, latent_size])
            fake_images = generator.predict(noise) 
```

```py
 # train the discriminator network
            # real data label=1, fake data label=-1
            # instead of 1 combined batch of real and fake images,
            # train with 1 batch of real data first, then 1 batch
            # of fake images.
            # this tweak prevents the gradient 
            # from vanishing due to opposite
            # signs of real and fake data labels (i.e. +1 and -1) and 
            # small magnitude of weights due to clipping.
            real_loss, real_acc = \
                discriminator.train_on_batch(real_images,
                                             real_labels)
            fake_loss, fake_acc = \
                discriminator.train_on_batch(fake_images,
                                             -real_labels)
            # accumulate average loss and accuracy
            loss += 0.5 * (real_loss + fake_loss)
            acc += 0.5 * (real_acc + fake_acc)
            # clip discriminator weights to satisfy Lipschitz constraint
            for layer in discriminator.layers:
                weights = layer.get_weights()
                weights = [np.clip(weight,
                                   -clip_value,
                                   clip_value) for weight in weights]
                layer.set_weights(weights) 
```

```py
 # average loss and accuracy per n_critic training iterations
        loss /= n_critic
        acc /= n_critic
        log = "%d: [discriminator loss: %f, acc: %f]" % (i, loss, acc) 
```

```py
 # train the adversarial network for 1 batch
        # 1 batch of fake images with label=1.0
        # since the discriminator weights are frozen in 
        # adversarial network only the generator is trained
        # generate noise using uniform distribution
        noise = np.random.uniform(-1.0,
                                  1.0,
                                  size=[batch_size, latent_size])
        # train the adversarial network
        # note that unlike in discriminator training,
        # we do not save the fake images in a variable
        # the fake images go to the discriminator 
        # input of the adversarial for classification
        # fake images are labelled as real
        # log the loss and accuracy
        loss, acc = adversarial.train_on_batch(noise, real_labels)
        log = "%s [adversarial loss: %f, acc: %f]" % (log, loss, acc)
        print(log)
        if (i + 1) % save_interval == 0:
            # plot generator images on a periodic basis
            gan.plot_images(generator,
                            noise_input=noise_input,
                            show=False,
                            step=(i + 1),
                            model_name=model_name) 
```

```py
 # save the model after training the generator
    # the trained generator can be reloaded 
    # for future MNIST digit generation
    generator.save(model_name + ".h5") 
```

W
wizardforcel 已提交
495
“图 5.1.4”显示了 MNIST 数据集上 WGAN 输出的演变:
W
wizardforcel 已提交
496

W
wizardforcel 已提交
497
![](img/B14853_05_04.png)
W
wizardforcel 已提交
498

W
wizardforcel 已提交
499
图 5.1.4:WGAN 与训练步骤的示例输出。 在训练和测试期间,WGAN 的任何输出均不会遭受模式崩溃
W
wizardforcel 已提交
500 501 502

即使在网络配置更改的情况下,WGAN 也稳定。 例如,当在识别符网络的 ReLU 之前插入批量规范化时,已知 DCGAN 不稳定。 在 WGAN 中,相同的配置是稳定的。

W
wizardforcel 已提交
503
下图“图 5.1.5”向我们展示了 DCGAN 和 WGAN 的输出,并在判别器网络上进行了批量归一化:
W
wizardforcel 已提交
504

W
wizardforcel 已提交
505
![](img/B14853_05_05.png)
W
wizardforcel 已提交
506

W
wizardforcel 已提交
507
图 5.1.5:在判别器网络中的 ReLU 激活之前插入批量归一化时,DCGAN(左)和 WGAN(右)的输出比较
W
wizardforcel 已提交
508 509 510 511 512 513 514

与上一章中的 GAN 训练相似,经过 40,000 个训练步骤,将训练后的模型保存在文件中。 使用训练有素的生成器模型,通过运行以下命令来生成新的合成 MNIST 数字图像:

```py
python3 wgan-mnist-5.1.2.py --generator=wgan_mnist.h5 
```

W
wizardforcel 已提交
515
正如我们所讨论的,原始 GAN 很难训练。 当 GAN 优化的损失函数时,就会出现问题。 实际上是在优化 *JS* 差异,`D[JS]`。 当两个分布函数之间几乎没有重叠时,很难优化`D[JS]`
W
wizardforcel 已提交
516 517 518 519 520

WGAN 提出通过使用 EMD 或 Wasserstein 1 损失函数来解决该问题,该函数即使在两个分布之间很少或没有重叠时也具有平滑的微分函数。 但是,WGAN 与生成的图像质量无关。 除了稳定性问题之外,原始 GAN 生成的图像在感知质量方面还有很多改进的地方。 LSGAN 理论上可以同时解决两个问题。 在下一节中,我们将介绍 LSGAN。

# 2.最小二乘 GAN(LSGAN)

W
wizardforcel 已提交
521
LSGAN 提出最小二乘损失。“图 5.2.1”演示了为什么在 GAN 中使用 Sigmoid交叉熵损失会导致生成的数据质量较差:
W
wizardforcel 已提交
522

W
wizardforcel 已提交
523
![](img/B14853_05_06.png)
W
wizardforcel 已提交
524

W
wizardforcel 已提交
525
图 5.2.1:真实样本和虚假样本分布均除以各自的决策边界:Sigmoid 和最小二乘
W
wizardforcel 已提交
526 527 528 529 530

理想情况下,假样本分布应尽可能接近真实样本的分布。 但是,对于 GAN,一旦伪样本已经位于决策边界的正确一侧,梯度就消失了。

这会阻止生成器具有足够的动机来提高生成的伪数据的质量。 远离决策边界的伪样本将不再试图靠近真实样本的分布。 使用最小二乘损失函数,只要假样本分布与真实样本的分布相距甚远,梯度就不会消失。 即使假样本已经位于决策边界的正确一侧,生成器也将努力改善其对实际密度分布的估计。

W
wizardforcel 已提交
531
“表 5.2.1”显示了 GAN,WGAN 和 LSGAN 之间的损耗函数的比较:
W
wizardforcel 已提交
532 533 534 535 536

<colgroup><col> <col> <col></colgroup> 
| **网络** | **损失函数** | **公式** |
| GAN | 

W
wizardforcel 已提交
537
![](img/B14853_05_095.png)
W
wizardforcel 已提交
538

W
wizardforcel 已提交
539
![](img/B14853_05_096.png)
W
wizardforcel 已提交
540 541 542 543

 | 4.1.14.1.5 |
| WGAN | 

W
wizardforcel 已提交
544
![](img/B14853_05_097.png)
W
wizardforcel 已提交
545

W
wizardforcel 已提交
546
![](img/B14853_05_098.png)
W
wizardforcel 已提交
547

W
wizardforcel 已提交
548
![](img/B14853_05_099.png)
W
wizardforcel 已提交
549 550 551 552

 | 5.1.215.1.225.1.20 |
| LSGAN | 

W
wizardforcel 已提交
553
![](img/B14853_05_100.png)
W
wizardforcel 已提交
554

W
wizardforcel 已提交
555
![](img/B14853_05_101.png)
W
wizardforcel 已提交
556 557 558 559 560

 | 5.2.15.2.2 |

表 5.2.1:GAN,WGAN 和 LSGAN 损失函数之间的比较

W
wizardforcel 已提交
561
最小化“公式 5.2.1”或判别器损失函数意味着实际数据分类与真实标签 1.0 之间的 MSE 应该接近零。 此外,假数据分类和真实标签 0.0 之间的 MSE 应该接近零。
W
wizardforcel 已提交
562

W
wizardforcel 已提交
563
与其他 GAN 相似,对 LSGAN 判别器进行了训练,可以从假数据样本中对真实数据进行分类。 最小化公式 5.2.2 意味着在标签 1.0 的帮助下,使判别器认为生成的假样本数据是真实的。
W
wizardforcel 已提交
564

W
wizardforcel 已提交
565
以上一章中的 DCGAN 代码为基础来实现 LSGAN 仅需进行一些更改。 如“列表 5.2.1”所示,删除了判别器 Sigmoid 激活。 判别器是通过调用以下命令构建的:
W
wizardforcel 已提交
566 567 568 569 570 571 572 573 574 575 576

```py
discriminator = gan.discriminator(inputs, activation=None) 
```

生成器类似于原始的 DCGAN:

```py
generator = gan.generator(inputs, image_size) 
```

W
wizardforcel 已提交
577
鉴别函数和对抗损失函数都被`mse`代替。 所有网络参数均与 DCGAN 中的相同。 `tf.keras`中 LSGAN 的网络模型类似于“图 4.2.1”,除了存在线性激活或无输出激活外。 训练过程类似于 DCGAN 中的训练过程,由实用程序功能提供:
W
wizardforcel 已提交
578 579 580 581 582

```py
gan.train(models, x_train, params) 
```

W
wizardforcel 已提交
583
“列表 5.2.1”:`lsgan-mnist-5.2.1.py`
W
wizardforcel 已提交
584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641

```py
def build_and_train_models():
    """Load the dataset, build LSGAN discriminator,
    generator, and adversarial models.
    Call the LSGAN train routine.
    """
    # load MNIST dataset
    (x_train, _), (_, _) = mnist.load_data()
    # reshape data for CNN as (28, 28, 1) and normalize
    image_size = x_train.shape[1]
    x_train = np.reshape(x_train,
                         [-1, image_size, image_size, 1])
    x_train = x_train.astype('float32') / 255
    model_name = "lsgan_mnist"
    # network parameters
    # the latent or z vector is 100-dim
    latent_size = 100
    input_shape = (image_size, image_size, 1)
    batch_size = 64
    lr = 2e-4
    decay = 6e-8
    train_steps = 40000
    # build discriminator model
    inputs = Input(shape=input_shape, name='discriminator_input')
    discriminator = gan.discriminator(inputs, activation=None)
    # [1] uses Adam, but discriminator easily 
    # converges with RMSprop
    optimizer = RMSprop(lr=lr, decay=decay)
    # LSGAN uses MSE loss [2]
    discriminator.compile(loss='mse',
                          optimizer=optimizer,
                          metrics=['accuracy'])
    discriminator.summary()
    # build generator model
    input_shape = (latent_size, )
    inputs = Input(shape=input_shape, name='z_input')
    generator = gan.generator(inputs, image_size)
    generator.summary()
    # build adversarial model = generator + discriminator
    optimizer = RMSprop(lr=lr*0.5, decay=decay*0.5)
    # freeze the weights of discriminator 
    # during adversarial training
    discriminator.trainable = False
    adversarial = Model(inputs,
                        discriminator(generator(inputs)),
                        name=model_name)
    # LSGAN uses MSE loss [2]
    adversarial.compile(loss='mse',
                        optimizer=optimizer,
                        metrics=['accuracy'])
    adversarial.summary()
    # train discriminator and adversarial networks
    models = (generator, discriminator, adversarial)
    params = (batch_size, latent_size, train_steps, model_name)
    gan.train(models, x_train, params) 
```

W
wizardforcel 已提交
642
“图 5.2.2”显示了使用 MNIST 数据集对 40,000 个训练步骤进行 LSGAN 训练后生成的样本:
W
wizardforcel 已提交
643

W
wizardforcel 已提交
644
![](img/B14853_05_07.png)
W
wizardforcel 已提交
645

W
wizardforcel 已提交
646
图 5.2.2:LSGAN 的示例输出与训练步骤
W
wizardforcel 已提交
647

W
wizardforcel 已提交
648
与上一章中 DCGAN 中的“图 4.2.1”相比,输出图像的感知质量更好。
W
wizardforcel 已提交
649 650 651 652 653 654 655 656 657 658 659

使用训练有素的生成器模型,通过运行以下命令来生成新的合成 MNIST 数字图像:

```py
python3 lsgan-mnist-5.2.1.py --generator=lsgan_mnist.h5 
```

在本节中,我们讨论了损失函数的另一种改进。 通过使用 MSE 或 L2,我们解决了训练 GAN 的稳定性和感知质量的双重问题。 在下一节中,提出了相对于 CGAN 的另一项改进,这已在上一章中进行了讨论。

# 3\. Auxiliary Classifier GAN (ACGAN)

W
wizardforcel 已提交
660
ACGAN 在原理上类似于我们在上一章中讨论的**条件 GAN****CGAN**)。 我们将比较 CGAN 和 ACGAN。 对于 CGAN 和 ACGAN,发生器输入均为噪声及其标签。 输出是属于输入类标签的伪图像。 对于 CGAN,判别器的输入是图像(假的或真实的)及其标签。 输出是图像真实的概率。 对于 ACGAN,判别器的输入是一幅图像,而输出是该图像是真实的且其类别是标签的概率。
W
wizardforcel 已提交
661

W
wizardforcel 已提交
662
“图 5.3.1”突出显示了生成器训练期间 CGAN 和 ACGAN 之间的区别:
W
wizardforcel 已提交
663

W
wizardforcel 已提交
664
![](img/B14853_05_08.png)
W
wizardforcel 已提交
665

W
wizardforcel 已提交
666
图 5.3.1:CGAN 与 ACGAN 生成器训练。 主要区别是判别器的输入和输出
W
wizardforcel 已提交
667 668 669

本质上,在 CGAN 中,我们向网络提供了边信息(标签)。 在 ACGAN 中,我们尝试使用辅助类解码器网络重建辅助信息。 ACGAN 理论认为,强制网络执行其他任务可以提高原始任务的性能。 在这种情况下,附加任务是图像分类。 原始任务是生成伪造图像。

W
wizardforcel 已提交
670
“表 5.3.1”显示了 ACGAN 损失函数与 CGAN 损失函数的比较:
W
wizardforcel 已提交
671 672 673 674 675

<colgroup><col> <col> <col></colgroup> 
| **网络** | **损失函数** | **编号** |
| CGAN | 

W
wizardforcel 已提交
676
![](img/B14853_05_102.png)
W
wizardforcel 已提交
677

W
wizardforcel 已提交
678
![](img/B14853_05_103.png)
W
wizardforcel 已提交
679 680 681 682

 | 4.3.14.3.2 |
| ACGAN | 

W
wizardforcel 已提交
683
![](img/B14853_05_104.png)
W
wizardforcel 已提交
684

W
wizardforcel 已提交
685
![](img/B14853_05_105.png)
W
wizardforcel 已提交
686 687 688 689 690

 | 5.3.15.3.2 |

表 5.3.1:CGAN 和 ACGAN 损失函数之间的比较

W
wizardforcel 已提交
691
ACGAN 损失函数与 CGAN 相同,除了附加的分类器损失函数。 除了从假图片中识别真实图像的原始任务之外,判别器的“公式 5.3.1”还具有对真假图像正确分类的附加任务。 生成器的“公式 5.3.2”意味着,除了尝试用伪造的图像来欺骗判别器(![](img/B14853_05_108.png))之外,它还要求判别器正确地对那些伪造的图像进行分类(![](img/B14853_05_109.png))。
W
wizardforcel 已提交
692

W
wizardforcel 已提交
693
从 CGAN 代码开始,仅修改判别器和训练功能以实现 ACGAN。 `gan.py`还提供了判别器和生成器构建器功能。 要查看判别器上所做的更改,清单 5.3.1 显示了构建器功能,其中突出显示了执行图像分类的辅助解码器网络和双输出。
W
wizardforcel 已提交
694

W
wizardforcel 已提交
695
“列表 5.3.1”:`gan.py`
W
wizardforcel 已提交
696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783

```py
def discriminator(inputs,
                  activation='sigmoid',
                  num_labels=None,
                  num_codes=None):
    """Build a Discriminator Model 
```

```py
 Stack of LeakyReLU-Conv2D to discriminate real from fake
    The network does not converge with BN so it is not used here
    unlike in [1]
    Arguments:
        inputs (Layer): Input layer of the discriminator (the image)
        activation (string): Name of output activation layer
        num_labels (int): Dimension of one-hot labels for ACGAN & InfoGAN
        num_codes (int): num_codes-dim Q network as output 
                    if StackedGAN or 2 Q networks if InfoGAN

    Returns:
        Model: Discriminator Model
    """
    kernel_size = 5
    layer_filters = [32, 64, 128, 256] 
```

```py
 x = inputs
    for filters in layer_filters:
        # first 3 convolution layers use strides = 2
        # last one uses strides = 1
        if filters == layer_filters[-1]:
            strides = 1
        else:
            strides = 2
        x = LeakyReLU(alpha=0.2)(x)
        x = Conv2D(filters=filters,
                   kernel_size=kernel_size,
                   strides=strides,
                   padding='same')(x) 
```

```py
 x = Flatten()(x)
    # default output is probability that the image is real
    outputs = Dense(1)(x)
    if activation is not None:
        print(activation)
        outputs = Activation(activation)(outputs) 
```

```py
 if num_labels:
        # ACGAN and InfoGAN have 2nd output
        # 2nd output is 10-dim one-hot vector of label
        layer = Dense(layer_filters[-2])(x)
        labels = Dense(num_labels)(layer)
        labels = Activation('softmax', name='label')(labels)
        if num_codes is None:
            outputs = [outputs, labels]
        else:
            # InfoGAN have 3rd and 4th outputs
            # 3rd output is 1-dim continous Q of 1st c given x
            code1 = Dense(1)(layer)
            code1 = Activation('sigmoid', name='code1')(code1) 
```

```py
 # 4th output is 1-dim continuous Q of 2nd c given x
            code2 = Dense(1)(layer)
            code2 = Activation('sigmoid', name='code2')(code2) 
```

```py
 outputs = [outputs, labels, code1, code2]
    elif num_codes is not None:
        # StackedGAN Q0 output
        # z0_recon is reconstruction of z0 normal distribution
        z0_recon =  Dense(num_codes)(x)
        z0_recon = Activation('tanh', name='z0')(z0_recon)
        outputs = [outputs, z0_recon] 
```

```py
 return Model(inputs, outputs, name='discriminator') 
```

W
wizardforcel 已提交
784
然后通过调用以下命令来构建判别器:
W
wizardforcel 已提交
785 786 787 788 789

```py
discriminator = gan.discriminator(inputs, num_labels=num_labels) 
```

W
wizardforcel 已提交
790
生成器与 WGAN 和 LSGAN 中的生成器相同。 回想一下,在以下“列表 5.3.2”中显示了生成器生成器。 我们应该注意,“列表 5.3.1”和“5.3.2”与上一节中 WGAN 和 LSGAN 使用的生成器功能相同。 重点介绍了适用于 LSGAN 的部件。
W
wizardforcel 已提交
791

W
wizardforcel 已提交
792
“列表 5.3.2”:`gan.py`
W
wizardforcel 已提交
793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886

```py
def generator(inputs,
              image_size,
              activation='sigmoid',
              labels=None,
              codes=None):
    """Build a Generator Model 
```

```py
 Stack of BN-ReLU-Conv2DTranpose to generate fake images.
    Output activation is sigmoid instead of tanh in [1].
    Sigmoid converges easily. 
```

```py
 Arguments:
        inputs (Layer): Input layer of the generator (the z-vector)
        image_size (int): Target size of one side 
            (assuming square image)
        activation (string): Name of output activation layer
        labels (tensor): Input labels
        codes (list): 2-dim disentangled codes for InfoGAN 
```

```py
 Returns:
        Model: Generator Model
    """
    image_resize = image_size // 4
    # network parameters
    kernel_size = 5
    layer_filters = [128, 64, 32, 1] 
```

```py
 if labels is not None:
        if codes is None:
            # ACGAN labels
            # concatenate z noise vector and one-hot labels
            inputs = [inputs, labels]
        else:
            # infoGAN codes
            # concatenate z noise vector, 
            # one-hot labels and codes 1 & 2
            inputs = [inputs, labels] + codes
        x = concatenate(inputs, axis=1)
    elif codes is not None:
        # generator 0 of StackedGAN
        inputs = [inputs, codes]
        x = concatenate(inputs, axis=1)
    else:
        # default input is just 100-dim noise (z-code)
        x = inputs 
```

```py
 x = Dense(image_resize * image_resize * layer_filters[0])(x)
    x = Reshape((image_resize, image_resize, layer_filters[0]))(x) 
```

```py
 for filters in layer_filters:
        # first two convolution layers use strides = 2
        # the last two use strides = 1
        if filters > layer_filters[-2]:
            strides = 2
        else:
            strides = 1
        x = BatchNormalization()(x)
        x = Activation('relu')(x)
        x = Conv2DTranspose(filters=filters,
                            kernel_size=kernel_size,
                            strides=strides,
                            padding='same')(x) 
```

```py
 if activation is not None:
        x = Activation(activation)(x) 
```

```py
 # generator output is the synthesized image x
    return Model(inputs, x, name='generator') 
```

在 ACGAN 中,生成器实例化为:

```py
generator = gan.generator(inputs, image_size, labels=labels) 
```

W
wizardforcel 已提交
887
“图 5.3.2”显示了`tf.keras`中 ACGAN 的网络模型:
W
wizardforcel 已提交
888

W
wizardforcel 已提交
889
![](img/B14853_05_09.png)
W
wizardforcel 已提交
890

W
wizardforcel 已提交
891
图 5.3.2:ACGAN 的`tf.keras`模型
W
wizardforcel 已提交
892

W
wizardforcel 已提交
893
如“列表 5.3.3”所示,对判别器和对抗模型进行了修改,以适应判别器网络中的更改。 现在,我们有两个损失函数。 首先是原始的二进制交叉熵,用于训练判别器来估计输入图像为实的概率。
W
wizardforcel 已提交
894 895 896

第二个是图像分类器,用于预测类别标签。 输出是一个 10 维的单热向量。

W
wizardforcel 已提交
897
“列表 5.3.3”:`acgan-mnist-5.3.1.py`
W
wizardforcel 已提交
898

W
wizardforcel 已提交
899
重点介绍了在判别器和对抗网络中实现的更改:
W
wizardforcel 已提交
900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997

```py
def build_and_train_models():
    """Load the dataset, build ACGAN discriminator,
    generator, and adversarial models.
    Call the ACGAN train routine.
    """
    # load MNIST dataset
    (x_train, y_train), (_, _) = mnist.load_data() 
```

```py
 # reshape data for CNN as (28, 28, 1) and normalize
    image_size = x_train.shape[1]
    x_train = np.reshape(x_train,
                         [-1, image_size, image_size, 1])
    x_train = x_train.astype('float32') / 255 
```

```py
 # train labels
    num_labels = len(np.unique(y_train))
    y_train = to_categorical(y_train) 
```

```py
 model_name = "acgan_mnist"
    # network parameters
    latent_size = 100
    batch_size = 64
    train_steps = 40000
    lr = 2e-4
    decay = 6e-8
    input_shape = (image_size, image_size, 1)
    label_shape = (num_labels, ) 
```

```py
 # build discriminator Model
    inputs = Input(shape=input_shape,
                   name='discriminator_input')
    # call discriminator builder 
    # with 2 outputs, pred source and labels
    discriminator = gan.discriminator(inputs,
                                      num_labels=num_labels) 
```

```py
 # [1] uses Adam, but discriminator 
    # easily converges with RMSprop
    optimizer = RMSprop(lr=lr, decay=decay)
    # 2 loss fuctions: 1) probability image is real
    # 2) class label of the image
    loss = ['binary_crossentropy', 'categorical_crossentropy']
    discriminator.compile(loss=loss,
                          optimizer=optimizer,
                          metrics=['accuracy'])
    discriminator.summary() 
```

```py
 # build generator model
    input_shape = (latent_size, )
    inputs = Input(shape=input_shape, name='z_input')
    labels = Input(shape=label_shape, name='labels')
    # call generator builder with input labels
    generator = gan.generator(inputs,
                              image_size,
                              labels=labels)
    generator.summary() 
```

```py
 # build adversarial model = generator + discriminator
    optimizer = RMSprop(lr=lr*0.5, decay=decay*0.5)
    # freeze the weights of discriminator 
    # during adversarial training
    discriminator.trainable = False
    adversarial = Model([inputs, labels],
                        discriminator(generator([inputs, labels])),
                        name=model_name)
    # same 2 loss fuctions: 1) probability image is real
    # 2) class label of the image
    adversarial.compile(loss=loss,
                        optimizer=optimizer,
                        metrics=['accuracy'])
    adversarial.summary() 
```

```py
 # train discriminator and adversarial networks
    models = (generator, discriminator, adversarial)
    data = (x_train, y_train)
    params = (batch_size, latent_size, \
             train_steps, num_labels, model_name)
    train(models, data, params) 
```

W
wizardforcel 已提交
998
在“列表 5.3.4”中,我们重点介绍了训练例程中实现的更改。 将与 CGAN 代码进行比较的主要区别在于,必须在鉴别和对抗训练中提供输出标签。
W
wizardforcel 已提交
999

W
wizardforcel 已提交
1000
“列表 5.3.4”:`acgan-mnist-5.3.1.py`
W
wizardforcel 已提交
1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119 1120 1121 1122 1123 1124 1125 1126 1127 1128 1129 1130

```py
def train(models, data, params):
    """Train the discriminator and adversarial Networks
    Alternately train discriminator and adversarial 
    networks by batch.
    Discriminator is trained first with real and fake 
    images and corresponding one-hot labels.
    Adversarial is trained next with fake images pretending 
    to be real and corresponding one-hot labels.
    Generate sample images per save_interval.
    # Arguments
        models (list): Generator, Discriminator,
            Adversarial models
        data (list): x_train, y_train data
        params (list): Network parameters
    """
    # the GAN models
    generator, discriminator, adversarial = models
    # images and their one-hot labels
    x_train, y_train = data
    # network parameters
    batch_size, latent_size, train_steps, num_labels, model_name \
            = params
    # the generator image is saved every 500 steps
    save_interval = 500
    # noise vector to see how the generator 
    # output evolves during training
    noise_input = np.random.uniform(-1.0,
                                    1.0,
                                    size=[16, latent_size])
    # class labels are 0, 1, 2, 3, 4, 5, 
    # 6, 7, 8, 9, 0, 1, 2, 3, 4, 5
    # the generator must produce these MNIST digits
    noise_label = np.eye(num_labels)[np.arange(0, 16) % num_labels]
    # number of elements in train dataset
    train_size = x_train.shape[0]
    print(model_name,
          "Labels for generated images: ",
          np.argmax(noise_label, axis=1)) 
```

```py
 for i in range(train_steps):
        # train the discriminator for 1 batch
        # 1 batch of real (label=1.0) and fake images (label=0.0)
        # randomly pick real images and 
        # corresponding labels from dataset 
        rand_indexes = np.random.randint(0,
                                         train_size,
                                         size=batch_size)
        real_images = x_train[rand_indexes]
        real_labels = y_train[rand_indexes]
        # generate fake images from noise using generator
        # generate noise using uniform distribution
        noise = np.random.uniform(-1.0,
                                  1.0,
                                  size=[batch_size, latent_size])
        # randomly pick one-hot labels
        fake_labels = np.eye(num_labels)[np.random.choice(num_labels,
                                                          batch_size)]
        # generate fake images
        fake_images = generator.predict([noise, fake_labels])
        # real + fake images = 1 batch of train data
        x = np.concatenate((real_images, fake_images))
        # real + fake labels = 1 batch of train data labels
        labels = np.concatenate((real_labels, fake_labels)) 
```

```py
 # label real and fake images
        # real images label is 1.0
        y = np.ones([2 * batch_size, 1])
        # fake images label is 0.0
        y[batch_size:, :] = 0
        # train discriminator network, log the loss and accuracy
        # ['loss', 'activation_1_loss', 
        # 'label_loss', 'activation_1_acc', 'label_acc']
        metrics  = discriminator.train_on_batch(x, [y, labels])
        fmt = "%d: [disc loss: %f, srcloss: %f,"
        fmt += "lblloss: %f, srcacc: %f, lblacc: %f]"
        log = fmt % (i, metrics[0], metrics[1], \
                metrics[2], metrics[3], metrics[4]) 
```

```py
 # train the adversarial network for 1 batch
        # 1 batch of fake images with label=1.0 and
        # corresponding one-hot label or class 
        # since the discriminator weights are frozen 
        # in adversarial network only the generator is trained
        # generate noise using uniform distribution
        noise = np.random.uniform(-1.0,
                                  1.0,
                                  size=[batch_size, latent_size])
        # randomly pick one-hot labels
        fake_labels = np.eye(num_labels)[np.random.choice(num_labels,
                                                          batch_size)]
        # label fake images as real
        y = np.ones([batch_size, 1])
        # train the adversarial network 
        # note that unlike in discriminator training, 
        # we do not save the fake images in a variable
        # the fake images go to the discriminator input 
        # of the adversarial for classification
        # log the loss and accuracy
        metrics  = adversarial.train_on_batch([noise, fake_labels],
                                              [y, fake_labels])
        fmt = "%s [advr loss: %f, srcloss: %f,"
        fmt += "lblloss: %f, srcacc: %f, lblacc: %f]"
        log = fmt % (log, metrics[0], metrics[1],\
                metrics[2], metrics[3], metrics[4])
        print(log)
        if (i + 1) % save_interval == 0:
            # plot generator images on a periodic basis
            gan.plot_images(generator,
                        noise_input=noise_input,
                        noise_label=noise_label,
                        show=False,
                        step=(i + 1),
                        model_name=model_name) 
```

```py
 # save the model after training the generator
    # the trained generator can be reloaded 
    # for future MNIST digit generation
    generator.save(model_name + ".h5") 
```

W
wizardforcel 已提交
1131
可以看出,与其他任务相比,与我们之前讨论的所有 GAN 相比,ACGAN 的性能显着提高。 ACGAN 训练是稳定的,如“图 5.3.3”的 ACGAN 示例输出的以下标签所示:
W
wizardforcel 已提交
1132 1133 1134 1135 1136 1137 1138 1139 1140 1141

```py
[0    1    2    3
 4    5    6    7
 8    9    0    1
 2    3    4    5] 
```

与 CGAN 不同,样本输出的外观在训练过程中变化不大。 MNIST 数字图像的感知质量也更好。

W
wizardforcel 已提交
1142
![](img/B14853_05_10.png)
W
wizardforcel 已提交
1143

W
wizardforcel 已提交
1144
图 5.3.3:ACGAN 根据标签的训练步骤生成的示例输出`[0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5]`
W
wizardforcel 已提交
1145 1146 1147 1148 1149 1150 1151 1152 1153 1154 1155 1156 1157

使用训练有素的生成器模型,通过运行以下命令来生成新的合成 MNIST 数字图像:

```py
python3 acgan-mnist-5.3.1.py --generator=acgan_mnist.h5 
```

或者,也可以请求生成要生成的特定数字(例如 3):

```py
python3 acgan-mnist-5.3.1.py --generator=acgan_mnist.h5 --digit=3 
```

W
wizardforcel 已提交
1158
“图 5.3.4”显示了 CGAN 和 ACGAN 产生的每个 MNIST 数字的并排比较。 ACGAN 中的数字 2-6 比 CGAN 中的数字质量更好:
W
wizardforcel 已提交
1159

W
wizardforcel 已提交
1160
![](img/B14853_05_11.png)
W
wizardforcel 已提交
1161 1162 1163 1164 1165 1166 1167

图 5.3.4:以数字 0 到 9 为条件的 CGAN 和 ACGAN 输出的并排比较

与 WGAN 和 LSGAN 相似,ACGAN 通过微调的损失函数,对现有 GAN CGAN 进行了改进。 在接下来的章节中,我们将发现新的损失函数,这些函数将使 GAN 能够执行新的有用任务。

# 4。结论

W
wizardforcel 已提交
1168
在本章中,我们介绍了对原始 GAN 算法的各种改进,这些改进在上一章中首次介绍。 WGAN 提出了一种通过使用 EMD 或 Wasserstein 1 损失来提高训练稳定性的算法。 LSGAN 认为,与最小二乘损失不同,GANs 的原始交叉熵函数倾向于消失梯度。 LSGAN 提出了一种实现稳定训练和高质量输出的算法。 ACGAN 通过要求判别器在确定输入图像是假的还是真实的基础上执行分类任务,来令人信服地提高了 MNIST 数字有条件生成的质量。
W
wizardforcel 已提交
1169 1170 1171 1172 1173

在下一章中,我们将研究如何控制发电机输出的属性。 尽管 CGAN 和 ACGAN 可以指示要生成的期望数字,但我们尚未分析可以指定输出属性的 GAN。 例如,我们可能想要控制 MNIST 数字的书写风格,例如圆度,倾斜角度和厚度。 因此,目标是引入具有纠缠表示的 GAN,以控制生成器输出的特定属性。

# 5.参考

W
wizardforcel 已提交
1174 1175 1176 1177
1.  `Ian Goodfellow et al.: Generative Adversarial Nets. Advances in neural information processing systems, 2014 (http://papers.nips.cc/paper/5423-generative-adversarial-nets.pdf).`
1.  `Martin Arjovsky, Soumith Chintala, and Léon Bottou: Wasserstein GAN. arXiv preprint, 2017 (https://arxiv.org/pdf/1701.07875.pdf).`
1.  `Xudong Mao et al.: Least Squares Generative Adversarial Networks. 2017 IEEE International Conference on Computer Vision (ICCV). IEEE 2017 (http://openaccess.thecvf.com/content_ICCV_2017/papers/Mao_Least_Squares_Generative_ICCV_2017_paper.pdf).`
1.  `Augustus Odena, Christopher Olah, and Jonathon Shlens. Conditional Image Synthesis with Auxiliary Classifier GANs. ICML, 2017 (http://proceedings.mlr.press/v70/odena17a/odena17a.pdf).`