未验证 提交 d5a4588a 编写于 作者: W whs 提交者: GitHub

Add conditional GAN (#1270)

* Add conditionalGAN

* Fix readme.

* Fix image name in readme.

* Add DCGAN.

* Rename conditional_gan to c_gan.

* Move c_gan and cycle_gan to gan/
上级 ee694b68

运行本目录下的程序示例需要使用PaddlePaddle develop最新版本。如果您的PaddlePaddle安装版本低于此要求,请按照[安装文档](http://www.paddlepaddle.org/docs/develop/documentation/zh/build_and_install/pip_install_cn.html)中的说明更新PaddlePaddle安装版本。
## 代码结构
```
├── network.py # 定义基础生成网络和判别网络。
├── utility.py # 定义通用工具方法。
├── dc_gan.py # DCGAN训练脚本。
└── c_gan.py # conditionalGAN训练脚本。
```
## 简介
TODO
## 数据准备
本教程使用 mnist 数据集来进行模型的训练测试工作,该数据集通过`paddle.dataset`模块自动下载到本地。
## 训练测试conditionalGAN
在GPU单卡上训练conditionalGAN:
```
env CUDA_VISIBLE_DEVICES=0 python c_gan.py --output="./result"
```
训练过程中,每隔固定的训练轮数,会取一个batch的数据进行测试,测试结果以图片的形式保存至`--output`选项指定的路径。
执行`python c_gan.py --help`可查看更多使用方式和参数详细说明。
图1为conditionalGAN训练损失示意图,其中横坐标轴为训练轮数,纵轴为在训练集上的损失。其中,'G_loss'和'D_loss'分别为生成网络和判别器网络的训练损失。
<p align="center">
<img src="images/conditionalGAN_loss.png" width="620" hspace='10'/> <br/>
<strong>图 1</strong>
</p>
conditionalGAN训练19轮的模型预测效果如图2所示:
<p align="center">
<img src="images/conditionalGAN_demo.png" width="620" hspace='10'/> <br/>
<strong>图 2</strong>
</p>
## 训练测试DCGAN
在GPU单卡上训练DCGAN:
```
env CUDA_VISIBLE_DEVICES=0 python dc_gan.py --output="./result"
```
训练过程中,每隔固定的训练轮数,会取一个batch的数据进行测试,测试结果以图片的形式保存至`--output`选项指定的路径。
执行`python dc_gan.py --help`可查看更多使用方式和参数详细说明。
DCGAN训练10轮的模型预测效果如图3所示:
<p align="center">
<img src="images/DCGAN_demo.png" width="620" hspace='10'/> <br/>
<strong>图 3</strong>
</p>
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import os
import argparse
import functools
import matplotlib
import numpy as np
import paddle
import paddle.fluid as fluid
from utility import get_parent_function_name, plot, check, add_arguments, print_arguments
from network import G_cond, D_cond
matplotlib.use('agg')
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
# Length of the noise vector fed to the generator (see the 'noise' data layer
# and the np.random.uniform calls in train()).
NOISE_SIZE = 100
# Learning rate handed to the Adam optimizer in train().
LEARNING_RATE = 2e-4
# Command-line interface. add_arg is utility.add_arguments pre-bound to this
# parser, so each call below registers one "--<name>" option.
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size', int, 121, "Minibatch size.")
add_arg('epoch', int, 20, "The number of epoched to be trained.")
add_arg('output', str, "./output", "The directory the model and the test result to be saved to.")
add_arg('use_gpu', bool, True, "Whether to use GPU to train.")
# yapf: enable
def loss(x, label):
    """Return the mean of ``x`` scaled element-wise by ``label - 0.5``.

    With labels of 1 and 0 the scale factor is +0.5 and -0.5 respectively,
    so the label only decides the sign applied to ``x`` before averaging.

    Args:
        x: Fluid variable holding the discriminator output.
        label: Fluid variable or plain number; train() passes both forms.

    Returns:
        The fluid variable produced by ``fluid.layers.mean``.
    """
    signed_weight = label - 0.5
    weighted = x * signed_weight
    return fluid.layers.mean(weighted)
def train(args):
    """Train the conditional GAN on MNIST and periodically save sample grids.

    Two fluid programs are built: ``d_program`` evaluates the discriminator
    ``D_cond`` on fed images, and ``dg_program`` chains the generator
    ``G_cond`` into ``D_cond``. One Adam optimizer minimizes both losses; the
    second ``minimize`` call receives only the parameter names found in the
    generator-only clone of ``dg_program`` as its ``parameter_list``.

    Every 10th batch a grid made of the current real batch followed by images
    generated from a fixed noise sample is written to ``args.output`` as
    ``<pass>_<batch>.png`` and the losses are printed.

    Args:
        args: Parsed command-line namespace. The attributes read here are
            ``batch_size``, ``epoch``, ``output`` and ``use_gpu``.
    """
    d_program = fluid.Program()
    dg_program = fluid.Program()

    # Discriminator-only program: scores a fed image under its condition.
    with fluid.program_guard(d_program):
        conditions = fluid.layers.data(
            name='conditions', shape=[1], dtype='float32')
        img = fluid.layers.data(name='img', shape=[784], dtype='float32')
        label = fluid.layers.data(name='label', shape=[1], dtype='float32')
        d_logit = D_cond(img, conditions)
        d_loss = loss(d_logit, label)

    # Generator followed by discriminator.
    with fluid.program_guard(dg_program):
        conditions = fluid.layers.data(
            name='conditions', shape=[1], dtype='float32')
        noise = fluid.layers.data(
            name='noise', shape=[NOISE_SIZE], dtype='float32')
        g_img = G_cond(z=noise, y=conditions)

        # Clone before D_cond is appended so these programs hold the
        # generator only; they are used for sampling images below.
        g_program = dg_program.clone()
        g_program_test = dg_program.clone(for_test=True)

        dg_logit = D_cond(g_img, conditions)
        dg_loss = loss(dg_logit, 1)

    opt = fluid.optimizer.Adam(learning_rate=LEARNING_RATE)

    opt.minimize(loss=d_loss)
    # Restrict the generator step to the parameters present in g_program.
    parameters = [p.name for p in g_program.global_block().all_parameters()]
    opt.minimize(loss=dg_loss, parameter_list=parameters)

    exe = fluid.Executor(fluid.CPUPlace())
    if args.use_gpu:
        exe = fluid.Executor(fluid.CUDAPlace(0))
    exe.run(fluid.default_startup_program())

    train_reader = paddle.batch(
        paddle.reader.shuffle(
            paddle.dataset.mnist.train(), buf_size=60000),
        batch_size=args.batch_size)

    NUM_TRAIN_TIMES_OF_DG = 2
    # Fixed noise reused for every saved sample grid so that successive
    # snapshots are generated from the same latent inputs.
    const_n = np.random.uniform(
        low=-1.0, high=1.0,
        size=[args.batch_size, NOISE_SIZE]).astype('float32')

    for pass_id in range(args.epoch):
        for batch_id, data in enumerate(train_reader()):
            # Skip the trailing partial batch: noise and const_n are sized
            # for exactly args.batch_size samples.
            if len(data) != args.batch_size:
                continue

            noise_data = np.random.uniform(
                low=-1.0, high=1.0,
                size=[args.batch_size, NOISE_SIZE]).astype('float32')
            # A list comprehension (rather than map) keeps this working on
            # Python 3, where map() returns a lazy iterator.
            real_image = np.array([x[0] for x in data]).reshape(
                -1, 784).astype('float32')
            conditions_data = np.array([x[1] for x in data]).reshape(
                [-1, 1]).astype("float32")
            real_labels = np.ones(
                shape=[real_image.shape[0], 1], dtype='float32')
            fake_labels = np.zeros(
                shape=[real_image.shape[0], 1], dtype='float32')

            generated_image = exe.run(
                g_program,
                feed={'noise': noise_data,
                      'conditions': conditions_data},
                fetch_list=[g_img])[0]

            # Discriminator step on generated images, then on real images.
            d_loss_1 = exe.run(d_program,
                               feed={
                                   'img': generated_image,
                                   'label': fake_labels,
                                   'conditions': conditions_data
                               },
                               fetch_list=[d_loss])
            d_loss_2 = exe.run(d_program,
                               feed={
                                   'img': real_image,
                                   'label': real_labels,
                                   'conditions': conditions_data
                               },
                               fetch_list=[d_loss])
            d_loss_np = [d_loss_1[0][0], d_loss_2[0][0]]

            # Several generator steps per discriminator step, each with
            # freshly drawn noise.
            for _ in range(NUM_TRAIN_TIMES_OF_DG):
                noise_data = np.random.uniform(
                    low=-1.0, high=1.0,
                    size=[args.batch_size, NOISE_SIZE]).astype('float32')
                dg_loss_np = exe.run(
                    dg_program,
                    feed={'noise': noise_data,
                          'conditions': conditions_data},
                    fetch_list=[dg_loss])[0]

            if batch_id % 10 == 0:
                if not os.path.exists(args.output):
                    os.makedirs(args.output)
                # generate image each batch
                generated_images = exe.run(
                    g_program_test,
                    feed={'noise': const_n,
                          'conditions': conditions_data},
                    fetch_list=[g_img])[0]
                total_images = np.concatenate([real_image, generated_images])
                fig = plot(total_images)
                msg = "Epoch ID={0}\n Batch ID={1}\n D-Loss={2}\n DG-Loss={3}\n gen={4}".format(
                    pass_id, batch_id, d_loss_np, dg_loss_np,
                    check(generated_images))
                print(msg)
                plt.title(msg)
                plt.savefig(
                    '{}/{:04d}_{:04d}.png'.format(args.output, pass_id,
                                                  batch_id),
                    bbox_inches='tight')
                plt.close(fig)
if __name__ == "__main__":
    # Script entry point: parse the command-line flags registered on the
    # module-level parser, echo them, then run training.
    args = parser.parse_args()
    print_arguments(args)
    train(args)
import paddle
import paddle.fluid as fluid
from utility import get_parent_function_name
# Base filter count used for the generator layers (G, G_cond).
gf_dim = 64
# Base filter count used for the discriminator conv layers (D, D_cond).
df_dim = 64
# Size of the generator's first fully connected layer (G uses it as is,
# G_cond uses half of it).
gfc_dim = 1024 * 2
# Size of the discriminator's hidden fully connected layer.
dfc_dim = 1024
# Side length in pixels of the square images produced by G.
img_dim = 28
# Added to y_dim for the filter count of D_cond's first conv layer.
c_dim = 3
# Number of channels the condition `y` is reshaped to in D_cond / G_cond.
y_dim = 1
# Spatial size of the images produced by G_cond.
output_height = 28
output_width = 28
def bn(x, name=None, act='relu'):
    """Apply ``fluid.layers.batch_norm`` with explicitly named parameters.

    Args:
        x: Input fluid variable.
        name: Prefix for the scale/bias/moving-statistics names. When
            omitted it is derived from the call site through
            ``get_parent_function_name()``.
        act: Activation name forwarded to ``batch_norm``.

    Returns:
        The output variable of the batch-norm layer.
    """
    # get_parent_function_name() reads fixed call-stack depths, so it has to
    # be invoked directly from this function's frame.
    prefix = get_parent_function_name() if name is None else name
    bn_names = dict(
        param_attr=prefix + '1',
        bias_attr=prefix + '2',
        moving_mean_name=prefix + '3',
        moving_variance_name=prefix + '4',
        name=prefix)
    return fluid.layers.batch_norm(x, act=act, **bn_names)
def conv(x, num_filters, name=None, act=None):
    """Apply a 5x5 convolution followed by 2x2 pooling with stride 2.

    Thin wrapper around ``fluid.nets.simple_img_conv_pool`` that gives the
    weight and bias explicit names.

    Args:
        x: Input fluid variable.
        num_filters: Number of convolution filters.
        name: Prefix for the weight ('w') and bias ('b') names. When omitted
            it is derived from the call site through
            ``get_parent_function_name()``.
        act: Activation name forwarded to the layer.

    Returns:
        The output variable of the conv-pool group.
    """
    # Must be called from this frame: the helper reads fixed stack depths.
    prefix = get_parent_function_name() if name is None else name
    weight_name = prefix + 'w'
    bias_name = prefix + 'b'
    return fluid.nets.simple_img_conv_pool(
        input=x,
        filter_size=5,
        num_filters=num_filters,
        pool_size=2,
        pool_stride=2,
        param_attr=weight_name,
        bias_attr=bias_name,
        act=act)
def fc(x, num_filters, name=None, act=None):
    """Apply a fully connected layer with explicitly named weight and bias.

    Args:
        x: Input fluid variable.
        num_filters: Output size of the layer.
        name: Prefix for the weight ('w') and bias ('b') names. When omitted
            it is derived from the call site through
            ``get_parent_function_name()``.
        act: Activation name forwarded to ``fluid.layers.fc``.

    Returns:
        The output variable of the fully connected layer.
    """
    # Must be called from this frame: the helper reads fixed stack depths.
    prefix = get_parent_function_name() if name is None else name
    weight_name = prefix + 'w'
    bias_name = prefix + 'b'
    return fluid.layers.fc(
        input=x,
        size=num_filters,
        act=act,
        param_attr=weight_name,
        bias_attr=bias_name)
def deconv(x,
           num_filters,
           name=None,
           filter_size=5,
           stride=2,
           dilation=1,
           padding=2,
           output_size=None,
           act=None):
    """Apply ``fluid.layers.conv2d_transpose`` with named weight and bias.

    Args:
        x: Input fluid variable.
        num_filters: Number of output filters.
        name: Prefix for the weight ('w') and bias ('b') names. When omitted
            it is derived from the call site through
            ``get_parent_function_name()``.
        filter_size: Kernel size forwarded to the layer.
        stride: Stride forwarded to the layer.
        dilation: Dilation forwarded to the layer.
        padding: Padding forwarded to the layer.
        output_size: Optional output size forwarded to the layer.
        act: Activation name forwarded to the layer.

    Returns:
        The output variable of the transposed convolution.
    """
    # Must be called from this frame: the helper reads fixed stack depths.
    prefix = get_parent_function_name() if name is None else name
    layer_args = dict(
        input=x,
        param_attr=prefix + 'w',
        bias_attr=prefix + 'b',
        num_filters=num_filters,
        output_size=output_size,
        filter_size=filter_size,
        stride=stride,
        dilation=dilation,
        padding=padding,
        act=act)
    return fluid.layers.conv2d_transpose(**layer_args)
def conv_cond_concat(x, y):
    """Concatenate the conditioning tensor ``y`` onto ``x`` along axis 1.

    A tensor of ones shaped ``[batch, y.shape[1], x.shape[2], x.shape[3]]``
    is multiplied by ``y`` and the product is appended to ``x``.

    Args:
        x: Feature map variable (indexed here up to dimension 3).
        y: Conditioning variable whose dimension 1 gives the number of
            appended channels.

    Returns:
        The concatenated fluid variable.
    """
    tiled_shape = [-1, y.shape[1], x.shape[2], x.shape[3]]
    ones = fluid.layers.fill_constant_batch_size_like(
        x, tiled_shape, "float32", 1.0)
    tiled_condition = ones * y
    return fluid.layers.concat([x, tiled_condition], 1)
def D_cond(image, y):
    """Conditional discriminator: score ``image`` given condition ``y``.

    The condition is concatenated to the input image, to the first conv
    output and to both fully connected inputs.

    Args:
        image: Fluid variable reshaped here to ``[-1, 1, 28, 28]``.
        y: Condition variable, reshaped here to ``[-1, y_dim, 1, 1]`` for the
            feature-map concatenations.

    Returns:
        Output variable of the final size-1 fully connected layer (no
        activation).
    """
    # NOTE: conv/bn/fc derive their parameter names from the line number of
    # the call inside this function, so every helper call stays on its own
    # line.
    pixels = fluid.layers.reshape(x=image, shape=[-1, 1, 28, 28])
    cond_map = fluid.layers.reshape(y, [-1, y_dim, 1, 1])

    stage0 = conv_cond_concat(pixels, cond_map)
    stage0 = conv(stage0, c_dim + y_dim, act="leaky_relu")
    stage0 = conv_cond_concat(stage0, cond_map)

    stage1 = conv(stage0, df_dim + y_dim)
    stage1 = bn(stage1, act="leaky_relu")
    stage1 = fluid.layers.flatten(stage1, axis=1)
    stage1 = fluid.layers.concat([stage1, y], 1)

    stage2 = fc(stage1, dfc_dim)
    stage2 = bn(stage2, act='leaky_relu')
    stage2 = fluid.layers.concat([stage2, y], 1)

    logit = fc(stage2, 1)
    return logit
def G_cond(z, y):
    """Conditional generator: map noise ``z`` and condition ``y`` to images.

    The condition is concatenated to the input of every layer. Two fully
    connected layers feed two transposed convolutions that upsample from a
    quarter of the output size to the full ``output_height x output_width``.

    Args:
        z: Noise variable.
        y: Condition variable, also reshaped to ``[-1, y_dim, 1, 1]`` for the
            feature-map concatenations.

    Returns:
        Fluid variable reshaped to ``[-1, output_height * output_width]``;
        the last layer uses a tanh activation.
    """
    # Floor division keeps every size an int on Python 3 as well, where `/`
    # would yield floats that are not valid layer sizes or shapes.
    s_h, s_w = output_height, output_width
    s_h2, s_h4 = s_h // 2, s_h // 4
    s_w2, s_w4 = s_w // 2, s_w // 4

    yb = fluid.layers.reshape(y, [-1, y_dim, 1, 1])  # NCHW

    # NOTE: fc/bn/deconv derive their parameter names from the line number of
    # the call inside this function, so two calls to the same helper must
    # not share a line.
    z = fluid.layers.concat([z, y], 1)
    h0 = bn(fc(z, gfc_dim // 2), act='relu')
    h0 = fluid.layers.concat([h0, y], 1)

    h1 = bn(fc(h0, gf_dim * 2 * s_h4 * s_w4), act='relu')
    h1 = fluid.layers.reshape(h1, [-1, gf_dim * 2, s_h4, s_w4])
    h1 = conv_cond_concat(h1, yb)

    h2 = bn(deconv(h1, gf_dim * 2, output_size=[s_h2, s_w2]), act='relu')
    h2 = conv_cond_concat(h2, yb)

    h3 = deconv(h2, 1, output_size=[s_h, s_w], act='tanh')
    return fluid.layers.reshape(h3, shape=[-1, s_h * s_w])
def D(x):
    """Unconditional discriminator: score a batch of images.

    Args:
        x: Fluid variable reshaped here to ``[-1, 1, 28, 28]``.

    Returns:
        Output variable of the final size-1 fully connected layer (no
        activation).
    """
    # NOTE: conv/bn/fc derive their parameter names from the line number of
    # the call inside this function, so every helper call stays on its own
    # line.
    pixels = fluid.layers.reshape(x=x, shape=[-1, 1, 28, 28])
    features = conv(pixels, df_dim, act='leaky_relu')
    features = conv(features, df_dim * 2)
    features = bn(features, act='leaky_relu')
    hidden = fc(features, dfc_dim)
    hidden = bn(hidden, act='leaky_relu')
    logit = fc(hidden, 1, act=None)
    return logit
def G(x):
    """Unconditional generator: map a noise batch to flattened images.

    Two fully connected layers (each followed by batch norm) feed two
    transposed convolutions with output sizes 14x14 and 28x28.

    Args:
        x: Noise variable.

    Returns:
        Fluid variable reshaped to ``[-1, 28 * 28]``; the last layer uses a
        tanh activation.
    """
    # Floor division keeps the sizes ints on Python 3 as well, where `/`
    # would yield floats that are not valid layer sizes or reshape dims.
    quarter = img_dim // 4

    # NOTE: fc/bn/deconv derive their parameter names from the line number of
    # the call inside this function, so two calls to the same helper must
    # not share a line.
    x = bn(fc(x, gfc_dim))
    x = bn(fc(x, gf_dim * 2 * quarter * quarter))
    x = fluid.layers.reshape(x, [-1, gf_dim * 2, quarter, quarter])
    x = deconv(x, gf_dim * 2, act='relu', output_size=[14, 14])
    x = deconv(x, 1, filter_size=5, padding=2, act='tanh', output_size=[28, 28])
    x = fluid.layers.reshape(x, shape=[-1, 28 * 28])
    return x
import math
import distutils.util
import numpy as np
import inspect
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
# Side length in pixels of the square images that plot() tiles into a grid.
img_dim = 28
def get_parent_function_name():
    """Build a name prefix that identifies the call site two frames up.

    The result has the form ``"<grandcaller>.<caller>.<lineno>."`` where
    ``caller`` is the function that called this helper, ``grandcaller`` is
    the function that called ``caller`` and ``lineno`` is the line currently
    executing in ``grandcaller``.

    Returns:
        str: The dotted prefix, ending with a trailing dot.

    Raises:
        IndexError: If fewer than two frames sit above this function.
    """
    # inspect.stack() walks the whole call stack, so capture it once instead
    # of once per field.
    frames = inspect.stack()
    caller_frame = frames[1]
    grandcaller_frame = frames[2]
    return grandcaller_frame[3] + '.' + caller_frame[3] + '.' + str(
        grandcaller_frame[2]) + '.'
def plot(gen_data):
    """Tile a batch of images into one square grid and draw it.

    Args:
        gen_data: Array reshaped here to ``(count, img_dim, img_dim)``.
            Values are displayed on a grey scale clipped to ``[-1, 1]``.

    Returns:
        The matplotlib figure (8x8 inches, axes hidden) holding the grid.
    """
    pad_dim = 1
    cell = pad_dim + img_dim
    count = gen_data.shape[0]
    images = gen_data.reshape(count, img_dim, img_dim)

    # Smallest n such that an n x n grid holds every image.
    n = int(math.ceil(math.sqrt(count)))

    # Append blank images to fill the grid and put a pad_dim-wide border on
    # the top and left of every image.
    pad_widths = [[0, n * n - count], [pad_dim, 0], [pad_dim, 0]]
    padded = np.pad(images, pad_widths, 'constant')

    # (n, n, cell, cell) -> (n, cell, n, cell) -> one 2-D canvas.
    grid = padded.reshape((n, n, cell, cell))
    grid = grid.transpose((0, 2, 1, 3))
    canvas = grid.reshape((n * cell, n * cell))

    fig = plt.figure(figsize=(8, 8))
    plt.axis('off')
    plt.imshow(canvas, cmap='Greys_r', vmin=-1, vmax=1)
    return fig
def check(a):
    """Summarize the value distribution of ``a``.

    Args:
        a: Array-like of numbers; it is flattened and sorted first.

    Returns:
        list: ``[mean, min, max, q25, q75]`` where ``q25`` / ``q75`` are the
        sorted values at indices ``int(len * 0.25)`` and ``int(len * 0.75)``.
    """
    ordered = np.sort(np.array(a).flatten())
    count = len(ordered)
    mean_value = np.average(ordered)
    smallest = np.min(ordered)
    largest = np.max(ordered)
    lower_quartile = ordered[int(count * 0.25)]
    upper_quartile = ordered[int(count * 0.75)]
    return [mean_value, smallest, largest, lower_quartile, upper_quartile]
def print_arguments(args):
    """Print argparse's arguments, one ``name: value`` line each, sorted by name.

    Usage:

    .. code-block:: python

        parser = argparse.ArgumentParser()
        parser.add_argument("name", default="John", type=str, help="User name.")
        args = parser.parse_args()
        print_arguments(args)

    :param args: Input argparse.Namespace for printing.
    :type args: argparse.Namespace
    """
    print("----------- Configuration Arguments -----------")
    # dict.items() exists on both Python 2 and 3; iteritems() is Python 2
    # only and raises AttributeError on Python 3.
    for arg, value in sorted(vars(args).items()):
        print("%s: %s" % (arg, value))
    print("------------------------------------------------")
def add_arguments(argname, type, default, help, argparser, **kwargs):
    """Register one ``--<argname>`` option on ``argparser``.

    Usage:

    .. code-block:: python

        parser = argparse.ArgumentParser()
        add_arguments("name", str, "John", "User name.", parser)
        args = parser.parse_args()

    :param argname: Option name without the leading dashes.
    :param type: Value type; ``bool`` is parsed with
        ``distutils.util.strtobool`` so textual true/false values work.
    :param default: Default value of the option.
    :param help: Help text; the default value is appended to it.
    :param argparser: The ``argparse.ArgumentParser`` to extend.
    :param kwargs: Extra keyword arguments forwarded to ``add_argument``.
    """
    if type == bool:
        converter = distutils.util.strtobool
    else:
        converter = type
    flag = "--" + argname
    help_text = help + ' Default: %(default)s.'
    argparser.add_argument(
        flag, default=default, type=converter, help=help_text, **kwargs)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册