diff --git a/gan/README.md b/gan/README.md index 84e372c80061735f94e496a7e66e03c7f6b47679..700004f6905331455f00c99ce8a0aea531765d62 100644 --- a/gan/README.md +++ b/gan/README.md @@ -19,7 +19,7 @@ 图1. 生成模型概览

-为了解决上面这些模型的问题,人们又提出了本章所要介绍的另一种生成模型,对抗式生成网络。它相比于前面提到的方法,具有生成网络结构灵活,产生样本快,生成图像看起来更真实的优点。对抗式生成网络,也称为Generative Adversarial Network (GAN) \[[1](#参考文献)\]。GAN的核心思想是,为了更好地训练一个生成式神经元网络模型,我们引入一个判别式神经元网络模型来构造优化目标函数。 +为了解决上面这些模型的问题,人们又提出了本章所要介绍的另一种生成模型,对抗式生成网络(Generative Adversarial Network - GAN)\[[1](#参考文献)\]。它相比于前面提到的方法,具有生成网络结构灵活,产生样本快,生成图像看起来更真实的优点。GAN的核心思想是,为了更好地训练一个生成式神经元网络模型,我们引入一个判别式神经元网络模型来构造优化目标函数。 ## 效果展示 @@ -33,7 +33,7 @@ ## 模型概览 ### 对抗式网络结构 -对抗式生成网络的基本结构是将一个已知概率分布的随机变量$z$,通过参数化的概率生成模型(通常是用一个神经网络模型来进行参数化),变换后得到一个生成的概率分布(图3中绿色的分布)。训练生成模型的过程就是调节生成模型的参数,使得生成的概率分布趋向于真实数据的概率分布(图3中蓝色的分布)。 +对抗式生成网络的基本结构是将一个已知概率分布的随机变量$z$,通过参数化的概率生成模型(通常是用一个神经网络模型来进行参数化),变换后得到一个生成的概率分布。训练生成模型的过程就是调节生成模型的参数,使得生成的概率分布趋向于真实数据的概率分布。 对抗式生成网络和之前的生成模型最大的创新就在于,用一个判别式神经网络来描述生成的概率分布和真实数据概率分布之间的差别。也就是说,我们用一个判别式模型 D 辅助构造优化目标函数,来训练一个生成式模型 G。G和D在训练时是处在相互对抗的角色下,G的目标是尽量生成和真实数据看起来相似的伪数据,从而使得D无法分别数据的真伪;而D的目标是能尽量分别出哪些是真实数据,哪些是G生成的伪数据。两者在竞争的条件下,能够相互提高各自的能力,最后收敛到一个均衡点:生成器生成的数据分布和真实数据分布完全一样,而判别器完全无法区分数据的真伪。 @@ -141,7 +141,7 @@ settings( ``` ### 模型结构 -本章里我们主要用到两种模型。一种是基本的GAN模型,主要由全连接层搭建,在gan_conf.py里面定义。另一种是DCGAN模型,主要由卷积层搭建,在gan_conf_image.py里面定义。 +本章里我们主要用到两种模型。一种是基本的GAN模型,主要由全连接层搭建,在`gan_conf.py`里面定义。另一种是DCGAN模型,主要由卷积层搭建,在`gan_conf_image.py`里面定义。 #### 对抗式网络基本框架 在文件`gan_conf.py`和`gan_conf_image.py`当中我们都定义了三个网络, **generator_training**, **discriminator_training** and **generator**. 和前文提到的模型结构的关系是:**discriminator_training** 是判别器,**generator** 是生成器,**generator_training** 是生成器加上判别器,这样构造的原因是因为训练生成器时需要用到判别器提供目标函数。这个对应关系在下面这段代码中定义: @@ -459,12 +459,14 @@ def discriminator(sample): ``` ## 训练模型 + ### 用Paddle API解析模型设置并创建trainer 为了能够训练在上面的模型配置文件中定义的网络,我们首先需要用Paddle API完成如下几个步骤: 1. 初始化Paddle环境 2. 解析设置 3. 由设置创造GradientMachine以及由GradientMachine创造trainer +4. 初始化并同步不同trainer里的参数 这几步分别由下面几段代码实现: @@ -492,32 +494,89 @@ generator_conf.model_config) # 由GradientMachine创造trainer dis_trainer = api.Trainer.create(dis_conf, dis_training_machine) gen_trainer = api.Trainer.create(gen_conf, gen_training_machine) -``` - +# trainer.startTrain()会初始化该trainer所对应的网络的参数 +dis_trainer.startTrain() +gen_trainer.startTrain() -为了能够平衡生成器和判别器之间的能力,我们依据它们各自的损失函数的大小来决定训练对象,即我们选择训练那个损失函数更大的网络。损失函数的值可以通过调用`GradientMachine`的`forward`方法来计算。 - -```python -def get_training_loss(training_machine, inputs): - outputs = api.Arguments.createArguments(0) - training_machine.forward(inputs, outputs, api.PASS_TEST) - loss = outputs.getSlotValue(0).copyToNumpyMat() - return numpy.mean(loss) +# 由于初始化的参数是随机的,所以需要同步不同网络之间的参数 +copy_shared_parameters(gen_training_machine, dis_training_machine) +copy_shared_parameters(gen_training_machine, generator_machine) ``` -每当训练完一个网络,我们需要和其他几个网络同步互相分享的参数值。下面的代码展示了其中一个例子: +### 用trainer来训练模型 +根据前面模型概览里面的介绍,对抗式生成网络需要轮流训练生成器和判别器。下面的代码描述了具体的训练流程: ```python -# 训练gen_training -gen_trainer.trainOneDataBatch(batch_size, data_batch_gen) - -# 把gen_training中的参数同步到dis_training和generator当中 -copy_shared_parameters(gen_training_machine, -dis_training_machine) -copy_shared_parameters(gen_training_machine, generator_machine) +# 定义训练100个pass +for train_pass in xrange(100): + dis_trainer.startTrainPass() + gen_trainer.startTrainPass() + for i in xrange(num_iter): + noise = get_noise(batch_size, noise_dim) + # 准备真实数据样本 + data_batch_dis_pos = prepare_discriminator_data_batch_pos( + batch_size, data_np) + # 计算D关于真实数据样本的损失函数 + # 损失函数的值可以通过调用`GradientMachine`的`forward`方法来计算。 + dis_loss_pos = get_training_loss(dis_training_machine, + data_batch_dis_pos) + # 准备生成(伪)数据样本 + data_batch_dis_neg = prepare_discriminator_data_batch_neg( + generator_machine, batch_size, noise) + # 计算D关于生成数据样本的损失函数 + dis_loss_neg = get_training_loss(dis_training_machine, + data_batch_dis_neg) + + dis_loss = (dis_loss_pos + dis_loss_neg) / 2.0 + + data_batch_gen = prepare_generator_data_batch(batch_size, noise) + # 计算G关于生成数据样本的损失函数 + gen_loss = get_training_loss(gen_training_machine, data_batch_gen) + + if i % 100 == 0: + print "d_pos_loss is %s d_neg_loss is %s" % (dis_loss_pos, + dis_loss_neg) + print "d_loss is %s g_loss is %s" % (dis_loss, gen_loss) + + # 为了能够平衡生成器和判别器之间的能力 + # 我们依据它们各自的损失函数的大小来决定训练对象 + # 即我们选择训练那个损失函数更大的网络 + # 但同时我们也限制不去连续训练一个网络太多次 + if (not (curr_train == "dis" and curr_strike == MAX_strike)) and \ + ((curr_train == "gen" and curr_strike == MAX_strike) or dis_loss > gen_loss): + if curr_train == "dis": + curr_strike += 1 + else: + curr_train = "dis" + curr_strike = 1 + dis_trainer.trainOneDataBatch(batch_size, data_batch_dis_neg) + dis_trainer.trainOneDataBatch(batch_size, data_batch_dis_pos) + # 每当训练完一个网络,我们需要和其他几个网络同步互相分享的参数值。 + copy_shared_parameters(dis_training_machine, + gen_training_machine) + + else: + if curr_train == "gen": + curr_strike += 1 + else: + curr_train = "gen" + curr_strike = 1 + gen_trainer.trainOneDataBatch(batch_size, data_batch_gen) + # 每当训练完一个网络,我们需要和其他几个网络同步互相分享的参数值。 + copy_shared_parameters(gen_training_machine, + dis_training_machine) + copy_shared_parameters(gen_training_machine, generator_machine) + + dis_trainer.finishTrainPass() + gen_trainer.finishTrainPass() + # 在每个pass结束之后,保存生成数据 + fake_samples = get_fake_samples(generator_machine, batch_size, noise) + save_results(fake_samples, "./%s_samples/train_pass%s.png" % + (data_source, train_pass), data_source) ``` +### 训练脚本及结果 用MNIST手写数字图片训练对抗式生成网络可以用如下的命令。如果想用其他训练数据可以将参数`-d`改为uniform或者cifar。 ```bash @@ -543,29 +602,15 @@ I0105 17:16:37.172737 20517 TrainerInternal.cpp:165] Batch=100 samples=12800 Av ## 应用模型 -图片由训练好的生成器生成。以下的代码将随机向量输入到模型 G,通过向前传递得到生成的图片。 +以MNIST为例,在训练完成后,模型会保存在路径 mnist_params/pass-%05d 下,例如第100个pass的模型会保存在路径 output/pass-00099。可以用下面的命令,加载训练好的参数,来生成MNIST图片。生成的图片会保存在文件`generated_mnist_samples.png`里面。 -```python -# 噪音z是多维正态分布 -def get_noise(batch_size, noise_dim): - return numpy.random.normal(size=(batch_size, noise_dim)).astype('float32') - -def get_fake_samples(generator_machine, batch_size, noise): - gen_inputs = api.Arguments.createArguments(1) - gen_inputs.setSlotValue(0, api.Matrix.createDenseFromNumpy(noise)) - gen_outputs = api.Arguments.createArguments(0) - generator_machine.forward(gen_inputs, gen_outputs, api.PASS_TEST) - fake_samples = gen_outputs.getSlotValue(0).copyToNumpyMat() - return fake_samples - -# 在每个pass的最后,保存生成的图片 -noise = get_noise(batch_size, noise_dim) -fake_samples = get_fake_samples(generator_machine, batch_size, noise) +```bash +$python gan_trainer.py -d mnist --use_gpu 1 --model_dir mnist_params/pass-00059 ``` ## 总结 -本章中,我们首先介绍了生成模型的概念,并简单回顾了几种常见的生成模型。对抗式生成网络是近期出现的一种全新的生成模型,它是由一个生成器和一个分类器通过相互对抗的方法来训练。我们着重介绍和用PaddlePaddle实现了两种常见的GAN模型:基本GAN模型和DCGAN模型。 +本章中,我们首先介绍了生成模型的概念,并简单回顾了几种常见的生成模型。对抗式生成网络是近期出现的一种全新的生成模型,它是由一个生成器和一个分类器通过相互对抗的方法来训练。我们着重介绍和用PaddlePaddle实现了两种常见的GAN模型:基本GAN模型和DCGAN模型。GAN模型还有许多相关的变形和扩展,其中一种conditional GAN模型就是把之前GAN模型的随机输入变量z变成有意义的输入,那么输出的图片就可以根据输入的信号来调节。这个方法有许多实际的应用,例如模型可以根据输入的低分辨率的图片生成相应的高分辨率图片,达到图片超分辨的效果\[[9](#参考文献)\];另一个应用场景是根据输入的文字来生成相对应的图片\[[10](#参考文献)\]。读者可以在本章搭建的GAN模型框架下,尝试实现各种最新研究中的扩展模型。 ## 参考文献 @@ -578,3 +623,5 @@ fake_samples = get_fake_samples(generator_machine, batch_size, noise) 6. Goodfellow I. [NIPS 2016 Tutorial: Generative Adversarial Networks] (https://arxiv.org/pdf/1701.00160v1.pdf) [C] arXiv preprint arXiv:1701.00160. 2016 7. Dumoulin V. and Visin F. [A guide to convolution arithmetic for deep learning] (https://arxiv.org/pdf/1603.07285v1.pdf). arXiv preprint arXiv:1603.07285. 2016 8. Kingma D., Ba J. [Adam: A method for stochastic optimization] (https://arxiv.org/pdf/1412.6980v8.pdf) arXiv preprint arXiv:1412.6980. 2014 +9. Ledig C, Theis L, Huszár F, et al. [Photo-realistic single image super-resolution using a generative adversarial network] (https://arxiv.org/pdf/1609.04802.pdf) arXiv preprint arXiv:1609.04802. 2016 +10. Reed S, Akata Z, Yan X, et al. [Generative adversarial text to image synthesis] (https://arxiv.org/pdf/1605.05396v2.pdf) arXiv preprint arXiv:1605.05396. 2016