未验证 提交 e3dc5ab4 编写于 作者: W whs 提交者: GitHub

Add dc_gan and fix readme. (#1275)

* Add dc_gan and fix readme.

* Fix readme style.

* Fix readme.

* Fix readme.
上级 d5a4588a
...@@ -29,20 +29,29 @@ env CUDA_VISIBLE_DEVICES=0 python c_gan.py --output="./result" ...@@ -29,20 +29,29 @@ env CUDA_VISIBLE_DEVICES=0 python c_gan.py --output="./result"
执行`python c_gan.py --help`可查看更多使用方式和参数详细说明。 执行`python c_gan.py --help`可查看更多使用方式和参数详细说明。
图1为conditionalGAN训练损失示意图,其中横坐标轴为训练轮数,纵轴为在训练集上的损失。其中,'G_loss'和'D_loss'分别为生成网络和判别器网络的训练损失。 图1为conditionalGAN训练损失示意图,其中横坐标轴为训练轮数,纵轴为在训练集上的损失。其中,'G_loss'和'D_loss'分别为生成网络和判别器网络的训练损失。conditionalGAN训练19轮的模型预测效果如图2所示.
<p align="center"> <p style="background-color: #fff; align: center">
<img src="images/conditionalGAN_loss.png" width="620" hspace='10'/> <br/> <table>
<strong>图 1</strong> <tbody>
</p> <tr>
<td>
<img src="images/conditionalGAN_loss.png" width="400" hspace='10'/>
</td>
conditionalGAN训练19轮的模型预测效果如图2所示: <td>
<img src="images/conditionalGAN_demo.png" width="300" hspace='10'/>
<p align="center"> </td>
<img src="images/conditionalGAN_demo.png" width="620" hspace='10'/> <br/> </tr>
<strong>图 2</strong> <tr>
<td>
<strong align="center">图 1</strong>
</td>
<td>
<strong align="center">图 2</strong>
</td>
</tr>
</tbody>
</table>
</p> </p>
...@@ -59,9 +68,9 @@ env CUDA_VISIBLE_DEVICES=0 python dc_gan.py --output="./result" ...@@ -59,9 +68,9 @@ env CUDA_VISIBLE_DEVICES=0 python dc_gan.py --output="./result"
执行`python dc_gan.py --help`可查看更多使用方式和参数详细说明。 执行`python dc_gan.py --help`可查看更多使用方式和参数详细说明。
DCGAN训10轮的模型预测效果如图3所示: DCGAN训10轮的模型预测效果如图3所示:
<p align="center"> <p align="center">
<img src="images/DCGAN_demo.png" width="620" hspace='10'/> <br/> <img src="images/DCGAN_demo.png" width="300" hspace='10'/> <br/>
<strong>图 3</strong> <strong>图 3</strong>
</p> </p>
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import os
import argparse
import functools
import matplotlib
import numpy as np
import paddle
import paddle.fluid as fluid
from utility import get_parent_function_name, plot, check, add_arguments, print_arguments
from network import G, D
matplotlib.use('agg')
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
NOISE_SIZE = 100
LEARNING_RATE = 2e-4
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size', int, 121, "Minibatch size.")
add_arg('epoch', int, 20, "The number of epoched to be trained.")
add_arg('output', str, "./output", "The directory the model and the test result to be saved to.")
add_arg('use_gpu', bool, True, "Whether to use GPU to train.")
# yapf: enable
def loss(x, label):
return fluid.layers.mean(x * (label - 0.5))
def train(args):
d_program = fluid.Program()
dg_program = fluid.Program()
with fluid.program_guard(d_program):
img = fluid.layers.data(name='img', shape=[784], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='float32')
d_logit = D(img)
d_loss = loss(d_logit, label)
with fluid.program_guard(dg_program):
noise = fluid.layers.data(
name='noise', shape=[NOISE_SIZE], dtype='float32')
g_img = G(x=noise)
g_program = dg_program.clone()
g_program_test = dg_program.clone(for_test=True)
dg_logit = D(g_img)
dg_loss = loss(dg_logit, 1)
opt = fluid.optimizer.Adam(learning_rate=LEARNING_RATE)
opt.minimize(loss=d_loss)
parameters = [p.name for p in g_program.global_block().all_parameters()]
opt.minimize(loss=dg_loss, parameter_list=parameters)
exe = fluid.Executor(fluid.CPUPlace())
if args.use_gpu:
exe = fluid.Executor(fluid.CUDAPlace(0))
exe.run(fluid.default_startup_program())
train_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.mnist.train(), buf_size=60000),
batch_size=args.batch_size)
NUM_TRAIN_TIMES_OF_DG = 2
const_n = np.random.uniform(
low=-1.0, high=1.0,
size=[args.batch_size, NOISE_SIZE]).astype('float32')
for pass_id in range(args.epoch):
for batch_id, data in enumerate(train_reader()):
if len(data) != args.batch_size:
continue
noise_data = np.random.uniform(
low=-1.0, high=1.0,
size=[args.batch_size, NOISE_SIZE]).astype('float32')
real_image = np.array(map(lambda x: x[0], data)).reshape(
-1, 784).astype('float32')
real_labels = np.ones(
shape=[real_image.shape[0], 1], dtype='float32')
fake_labels = np.zeros(
shape=[real_image.shape[0], 1], dtype='float32')
total_label = np.concatenate([real_labels, fake_labels])
generated_image = exe.run(g_program,
feed={'noise': noise_data},
fetch_list={g_img})[0]
total_images = np.concatenate([real_image, generated_image])
d_loss_1 = exe.run(d_program,
feed={
'img': generated_image,
'label': fake_labels,
},
fetch_list={d_loss})
d_loss_2 = exe.run(d_program,
feed={
'img': real_image,
'label': real_labels,
},
fetch_list={d_loss})
d_loss_np = [d_loss_1[0][0], d_loss_2[0][0]]
for _ in xrange(NUM_TRAIN_TIMES_OF_DG):
noise_data = np.random.uniform(
low=-1.0, high=1.0,
size=[args.batch_size, NOISE_SIZE]).astype('float32')
dg_loss_np = exe.run(dg_program,
feed={'noise': noise_data},
fetch_list={dg_loss})[0]
if batch_id % 10 == 0:
if not os.path.exists(args.output):
os.makedirs(args.output)
# generate image each batch
generated_images = exe.run(g_program_test,
feed={'noise': const_n},
fetch_list={g_img})[0]
total_images = np.concatenate([real_image, generated_images])
fig = plot(total_images)
msg = "Epoch ID={0}\n Batch ID={1}\n D-Loss={2}\n DG-Loss={3}\n gen={4}".format(
pass_id, batch_id, d_loss_np, dg_loss_np,
check(generated_images))
print(msg)
plt.title(msg)
plt.savefig(
'{}/{:04d}_{:04d}.png'.format(args.output, pass_id,
batch_id),
bbox_inches='tight')
plt.close(fig)
if __name__ == "__main__":
args = parser.parse_args()
print_arguments(args)
train(args)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册