机器学习模型是一个*函数*,具有输入和输出。在这里讨论中,我们将输入视为一个*i*维向量\(\vec{x}\),其中元素为\(x_{i}\)。然后我们可以将模型*M*表示为输入的矢量值函数:\(\vec{y} = \vec{M}(\vec{x})\)。(我们将 M 的输出值视为矢量,因为一般来说,模型可能具有任意数量的输出。)
机器学习模型是一个*函数*,具有输入和输出。在这里讨论中,我们将输入视为一个*i*维向量$\vec{x}$,其中元素为$x_{i}$。然后我们可以将模型*M*表示为输入的矢量值函数:$\vec{y} = \vec{M}(\vec{x})$。(我们将 M 的输出值视为矢量,因为一般来说,模型可能具有任意数量的输出。)
GAN 是一个框架,用于教授深度学习模型捕获训练数据分布,以便我们可以从相同分布生成新数据。GAN 是由 Ian Goodfellow 于 2014 年发明的,并首次在论文[生成对抗网络](https://papers.nips.cc/paper/5423-generative-adversarial-nets.pdf)中描述。它们由两个不同的模型组成,一个*生成器*和一个*判别器*。生成器的任务是生成看起来像训练图像的“假”图像。判别器的任务是查看图像并输出它是来自真实训练图像还是来自生成器的假图像的概率。在训练过程中,生成器不断尝试欺骗判别器,生成越来越好的假图像,而判别器则努力成为更好的侦探,并正确分类真实和假图像。这个游戏的平衡是当生成器生成完美的假图像,看起来就像直接来自训练数据时,判别器总是以 50%的置信度猜测生成器的输出是真实的还是假的。
最后,现在我们已经定义了 GAN 框架的所有部分,我们可以开始训练。请注意,训练 GAN 有点像一种艺术形式,因为不正确的超参数设置会导致模式崩溃,而对出现问题的原因却没有太多解释。在这里,我们将紧密遵循[Goodfellow 的论文](https://papers.nips.cc/paper/5423-generative-adversarial-nets.pdf)中的算法 1,同时遵循[ganhacks](https://github.com/soumith/ganhacks)中显示的一些最佳实践。换句话说,我们将“为真实和伪造图像构建不同的小批量”,并调整 G 的目标函数以最大化\(log(D(G(z)))\)。训练分为两个主要部分。第一部分更新鉴别器,第二部分更新生成器。
最后,现在我们已经定义了 GAN 框架的所有部分,我们可以开始训练。请注意,训练 GAN 有点像一种艺术形式,因为不正确的超参数设置会导致模式崩溃,而对出现问题的原因却没有太多解释。在这里,我们将紧密遵循[Goodfellow 的论文](https://papers.nips.cc/paper/5423-generative-adversarial-nets.pdf)中的算法 1,同时遵循[ganhacks](https://github.com/soumith/ganhacks)中显示的一些最佳实践。换句话说,我们将“为真实和伪造图像构建不同的小批量”,并调整 G 的目标函数以最大化$log(D(G(z)))$。训练分为两个主要部分。第一部分更新鉴别器,第二部分更新生成器。
如原始论文所述,我们希望通过最小化\(log(1-D(G(z)))\)来训练生成器,以生成更好的伪造品。正如提到的,Goodfellow 指出,特别是在学习过程的早期,这并不能提供足够的梯度。为了解决这个问题,我们希望最大化\(log(D(G(z)))\)。在代码中,我们通过以下方式实现这一点:用鉴别器对第一部分的生成器输出进行分类,使用真实标签作为 GT 计算 G 的损失,通过反向传播计算 G 的梯度,最后使用优化器步骤更新 G 的参数。在损失函数中使用真实标签作为 GT 标签可能看起来有些反直觉,但这使我们可以使用`BCELoss`中的\(log(x)\)部分(而不是\(log(1-x)\)部分),这正是我们想要的。
如原始论文所述,我们希望通过最小化$log(1-D(G(z)))$来训练生成器,以生成更好的伪造品。正如提到的,Goodfellow 指出,特别是在学习过程的早期,这并不能提供足够的梯度。为了解决这个问题,我们希望最大化$log(D(G(z)))$。在代码中,我们通过以下方式实现这一点:用鉴别器对第一部分的生成器输出进行分类,使用真实标签作为 GT 计算 G 的损失,通过反向传播计算 G 的梯度,最后使用优化器步骤更新 G 的参数。在损失函数中使用真实标签作为 GT 标签可能看起来有些反直觉,但这使我们可以使用`BCELoss`中的$log(x)$部分(而不是$log(1-x)$部分),这正是我们想要的。
最后,我们将进行一些统计报告,并在每个时代结束时将我们的 fixed_noise 批次通过生成器,以直观地跟踪 G 的训练进度。报告的训练统计数据为: