diff --git a/fluid/gan/c_gan/README.md b/fluid/gan/c_gan/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..ea61dfb1623696947cdf3c5e405aeea2fbecb212
--- /dev/null
+++ b/fluid/gan/c_gan/README.md
@@ -0,0 +1,67 @@
+
+
+The example programs in this directory require the latest PaddlePaddle develop version. If your installed version of PaddlePaddle is older than this, please update it by following the instructions in the [installation guide](http://www.paddlepaddle.org/docs/develop/documentation/zh/build_and_install/pip_install_cn.html).
+
+## Code structure
+```
+├── network.py  # Defines the basic generator and discriminator networks.
+├── utility.py  # Defines common utility functions.
+├── dc_gan.py   # DCGAN training script.
+└── c_gan.py    # conditionalGAN training script.
+```
+
+## Introduction
+TODO
+
+## Data preparation
+
+This tutorial trains and tests the models on the MNIST dataset, which is downloaded automatically through the `paddle.dataset` module (see the short reader sketch at the end of this README).
+
+## Training and testing conditionalGAN
+
+Train conditionalGAN on a single GPU:
+
+```
+env CUDA_VISIBLE_DEVICES=0 python c_gan.py --output="./result"
+```
+
+At a fixed interval of training steps, a batch of data is tested and the test results are saved as images to the path specified by the `--output` option.
+
+Run `python c_gan.py --help` to see more usage options and detailed parameter descriptions.
+
+Figure 1 shows the conditionalGAN training loss: the horizontal axis is the number of training epochs and the vertical axis is the loss on the training set; 'G_loss' and 'D_loss' are the training losses of the generator and the discriminator, respectively. A short sketch of how these losses are obtained follows Figure 1.
+
+<p align="center">
+<img src="images/conditionalGAN_loss.png"/><br/>
+Figure 1
+</p>
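For reference, here is a minimal NumPy sketch (an editorial illustration, not part of the scripts in this PR) of how the loss curves above are obtained in `c_gan.py`: the discriminator program is run twice per batch, once on a generated batch with label 0 and once on a real batch with label 1, and the generator is updated through `dg_program` against a target label of 1. `gan_loss` mirrors the `loss` helper in `c_gan.py`; the logit values are made up.

```python
import numpy as np

def gan_loss(logits, label):
    # Mirrors loss() in c_gan.py: fluid.layers.mean(x * (label - 0.5))
    return np.mean(logits * (label - 0.5))

# Stand-in discriminator outputs for one batch (illustrative values only).
logits_on_fake = np.array([0.3, -0.1, 0.4], dtype='float32')
logits_on_real = np.array([-0.5, -0.2, -0.7], dtype='float32')

# c_gan.py logs D-Loss as the pair [loss on the generated batch, loss on the real batch] ...
d_loss = [gan_loss(logits_on_fake, 0.0), gan_loss(logits_on_real, 1.0)]
# ... and DG-Loss as the generator pass scored against label 1.
dg_loss = gan_loss(logits_on_fake, 1.0)
print(d_loss, dg_loss)
```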
+
+Figure 2 shows sample images generated by the conditionalGAN model after 19 epochs of training:
+
+<p align="center">
+<img src="images/conditionalGAN_demo.png"/><br/>
+Figure 2
+</p>
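The samples in Figure 2 are rendered inside the training loop of `c_gan.py` by feeding a fixed noise batch together with the batch labels into `g_program_test`. The PR does not ship a separate inference script, so the helper below is only a hypothetical sketch of how one could render a grid of a chosen digit in the same process after training; `exe`, `g_program_test`, `g_img`, `NOISE_SIZE` (from `c_gan.py`) and `plot` (from `utility.py`) are assumed to be in scope, and `sample_digit` itself is not part of the PR.

```python
import numpy as np

def sample_digit(digit, batch_size=121, out_path='./output/sample.png'):
    """Hypothetical helper: draw one grid of images for a single digit class."""
    noise = np.random.uniform(
        low=-1.0, high=1.0, size=[batch_size, NOISE_SIZE]).astype('float32')
    # The condition is just the digit value fed as a [batch_size, 1] float tensor,
    # matching the 'conditions' data layer in c_gan.py.
    conditions = np.full([batch_size, 1], float(digit), dtype='float32')
    images = exe.run(g_program_test,
                     feed={'noise': noise, 'conditions': conditions},
                     fetch_list=[g_img])[0]
    fig = plot(images)  # plot() from utility.py tiles the batch into a square grid
    fig.savefig(out_path, bbox_inches='tight')
```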
+
+## Training and testing DCGAN
+
+Train DCGAN on a single GPU:
+
+```
+env CUDA_VISIBLE_DEVICES=0 python dc_gan.py --output="./result"
+```
+
+At a fixed interval of training steps, a batch of data is tested and the test results are saved as images to the path specified by the `--output` option.
+
+Run `python dc_gan.py --help` to see more usage options and detailed parameter descriptions.
+
+Figure 3 shows sample images generated by the DCGAN model after 10 epochs of training:
+
+<p align="center">
+<img src="images/DCGAN_demo.png"/><br/>
+Figure 3
+</p>
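Relating to the Data preparation section above, here is a minimal sketch of the input pipeline that both training scripts build; it mirrors the reader construction in `c_gan.py`, and the batch size of 121 is that script's default.

```python
import paddle

# MNIST is downloaded automatically by paddle.dataset on first use.
train_reader = paddle.batch(
    paddle.reader.shuffle(paddle.dataset.mnist.train(), buf_size=60000),
    batch_size=121)

first_batch = next(train_reader())
image, label = first_batch[0]
# Each sample is a flattened 28x28 image (784 floats) plus an integer digit label.
print(len(first_batch), len(image), label)
```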
diff --git a/fluid/gan/c_gan/c_gan.py b/fluid/gan/c_gan/c_gan.py new file mode 100644 index 0000000000000000000000000000000000000000..10e899e349253352fa42cbf2571e5ad4be31a672 --- /dev/null +++ b/fluid/gan/c_gan/c_gan.py @@ -0,0 +1,171 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import os +import argparse +import functools +import matplotlib +import numpy as np +import paddle +import paddle.fluid as fluid +from utility import get_parent_function_name, plot, check, add_arguments, print_arguments +from network import G_cond, D_cond +matplotlib.use('agg') +import matplotlib.pyplot as plt +import matplotlib.gridspec as gridspec + +NOISE_SIZE = 100 +LEARNING_RATE = 2e-4 + +parser = argparse.ArgumentParser(description=__doc__) +add_arg = functools.partial(add_arguments, argparser=parser) +# yapf: disable +add_arg('batch_size', int, 121, "Minibatch size.") +add_arg('epoch', int, 20, "The number of epoched to be trained.") +add_arg('output', str, "./output", "The directory the model and the test result to be saved to.") +add_arg('use_gpu', bool, True, "Whether to use GPU to train.") +# yapf: enable + + +def loss(x, label): + return fluid.layers.mean(x * (label - 0.5)) + + +def train(args): + + d_program = fluid.Program() + dg_program = fluid.Program() + + with fluid.program_guard(d_program): + conditions = fluid.layers.data( + name='conditions', shape=[1], dtype='float32') + img = fluid.layers.data(name='img', shape=[784], dtype='float32') + label = fluid.layers.data(name='label', shape=[1], dtype='float32') + d_logit = D_cond(img, conditions) + d_loss = loss(d_logit, label) + + with fluid.program_guard(dg_program): + conditions = fluid.layers.data( + name='conditions', shape=[1], dtype='float32') + noise = fluid.layers.data( + name='noise', shape=[NOISE_SIZE], dtype='float32') + g_img = G_cond(z=noise, y=conditions) + + g_program = dg_program.clone() + g_program_test = dg_program.clone(for_test=True) + + dg_logit = D_cond(g_img, conditions) + dg_loss = loss(dg_logit, 1) + + opt = fluid.optimizer.Adam(learning_rate=LEARNING_RATE) + + opt.minimize(loss=d_loss) + parameters = [p.name for p in g_program.global_block().all_parameters()] + + opt.minimize(loss=dg_loss, parameter_list=parameters) + + exe = fluid.Executor(fluid.CPUPlace()) + if args.use_gpu: + exe = fluid.Executor(fluid.CUDAPlace(0)) + exe.run(fluid.default_startup_program()) + + train_reader = paddle.batch( + paddle.reader.shuffle( + paddle.dataset.mnist.train(), buf_size=60000), + batch_size=args.batch_size) + + NUM_TRAIN_TIMES_OF_DG = 2 + const_n = np.random.uniform( + low=-1.0, high=1.0, + size=[args.batch_size, NOISE_SIZE]).astype('float32') + for pass_id in range(args.epoch): + for batch_id, data in enumerate(train_reader()): + if len(data) != args.batch_size: + continue + noise_data = np.random.uniform( + low=-1.0, high=1.0, + size=[args.batch_size, NOISE_SIZE]).astype('float32') + real_image = np.array(map(lambda x: x[0], data)).reshape( + -1, 
784).astype('float32') + conditions_data = np.array([x[1] for x in data]).reshape( + [-1, 1]).astype("float32") + real_labels = np.ones( + shape=[real_image.shape[0], 1], dtype='float32') + fake_labels = np.zeros( + shape=[real_image.shape[0], 1], dtype='float32') + total_label = np.concatenate([real_labels, fake_labels]) + + generated_image = exe.run( + g_program, + feed={'noise': noise_data, + 'conditions': conditions_data}, + fetch_list={g_img})[0] + + total_images = np.concatenate([real_image, generated_image]) + + d_loss_1 = exe.run(d_program, + feed={ + 'img': generated_image, + 'label': fake_labels, + 'conditions': conditions_data + }, + fetch_list={d_loss}) + + d_loss_2 = exe.run(d_program, + feed={ + 'img': real_image, + 'label': real_labels, + 'conditions': conditions_data + }, + fetch_list={d_loss}) + + d_loss_np = [d_loss_1[0][0], d_loss_2[0][0]] + + for _ in xrange(NUM_TRAIN_TIMES_OF_DG): + noise_data = np.random.uniform( + low=-1.0, high=1.0, + size=[args.batch_size, NOISE_SIZE]).astype('float32') + dg_loss_np = exe.run( + dg_program, + feed={'noise': noise_data, + 'conditions': conditions_data}, + fetch_list={dg_loss})[0] + if batch_id % 10 == 0: + if not os.path.exists(args.output): + os.makedirs(args.output) + # generate image each batch + generated_images = exe.run( + g_program_test, + feed={'noise': const_n, + 'conditions': conditions_data}, + fetch_list={g_img})[0] + total_images = np.concatenate([real_image, generated_images]) + fig = plot(total_images) + msg = "Epoch ID={0}\n Batch ID={1}\n D-Loss={2}\n DG-Loss={3}\n gen={4}".format( + pass_id, batch_id, d_loss_np, dg_loss_np, + check(generated_images)) + print(msg) + plt.title(msg) + plt.savefig( + '{}/{:04d}_{:04d}.png'.format(args.output, pass_id, + batch_id), + bbox_inches='tight') + plt.close(fig) + + +if __name__ == "__main__": + args = parser.parse_args() + print_arguments(args) + train(args) diff --git a/fluid/gan/c_gan/images/DCGAN_demo.png b/fluid/gan/c_gan/images/DCGAN_demo.png new file mode 100644 index 0000000000000000000000000000000000000000..5c6a4f0b55e07826256801d2e2b7754ee0d5edef Binary files /dev/null and b/fluid/gan/c_gan/images/DCGAN_demo.png differ diff --git a/fluid/gan/c_gan/images/conditionalGAN_demo.png b/fluid/gan/c_gan/images/conditionalGAN_demo.png new file mode 100644 index 0000000000000000000000000000000000000000..234599e9e3a98e1872714375e96f288df16b9a50 Binary files /dev/null and b/fluid/gan/c_gan/images/conditionalGAN_demo.png differ diff --git a/fluid/gan/c_gan/images/conditionalGAN_loss.png b/fluid/gan/c_gan/images/conditionalGAN_loss.png new file mode 100644 index 0000000000000000000000000000000000000000..01af0a76c47d02a9e760e00d824422343d6f26bc Binary files /dev/null and b/fluid/gan/c_gan/images/conditionalGAN_loss.png differ diff --git a/fluid/gan/c_gan/network.py b/fluid/gan/c_gan/network.py new file mode 100644 index 0000000000000000000000000000000000000000..6a0dc073830e8dbeb28c7e9f96ae4795a0ab7fa8 --- /dev/null +++ b/fluid/gan/c_gan/network.py @@ -0,0 +1,142 @@ +import paddle +import paddle.fluid as fluid +from utility import get_parent_function_name + +gf_dim = 64 +df_dim = 64 +gfc_dim = 1024 * 2 +dfc_dim = 1024 +img_dim = 28 + +c_dim = 3 +y_dim = 1 +output_height = 28 +output_width = 28 + + +def bn(x, name=None, act='relu'): + if name is None: + name = get_parent_function_name() + #return fluid.layers.leaky_relu(x) + return fluid.layers.batch_norm( + x, + param_attr=name + '1', + bias_attr=name + '2', + moving_mean_name=name + '3', + moving_variance_name=name + '4', + 
name=name, + act=act) + + +def conv(x, num_filters, name=None, act=None): + if name is None: + name = get_parent_function_name() + return fluid.nets.simple_img_conv_pool( + input=x, + filter_size=5, + num_filters=num_filters, + pool_size=2, + pool_stride=2, + param_attr=name + 'w', + bias_attr=name + 'b', + act=act) + + +def fc(x, num_filters, name=None, act=None): + if name is None: + name = get_parent_function_name() + return fluid.layers.fc(input=x, + size=num_filters, + act=act, + param_attr=name + 'w', + bias_attr=name + 'b') + + +def deconv(x, + num_filters, + name=None, + filter_size=5, + stride=2, + dilation=1, + padding=2, + output_size=None, + act=None): + if name is None: + name = get_parent_function_name() + return fluid.layers.conv2d_transpose( + input=x, + param_attr=name + 'w', + bias_attr=name + 'b', + num_filters=num_filters, + output_size=output_size, + filter_size=filter_size, + stride=stride, + dilation=dilation, + padding=padding, + act=act) + + +def conv_cond_concat(x, y): + """Concatenate conditioning vector on feature map axis.""" + ones = fluid.layers.fill_constant_batch_size_like( + x, [-1, y.shape[1], x.shape[2], x.shape[3]], "float32", 1.0) + return fluid.layers.concat([x, ones * y], 1) + + +def D_cond(image, y): + image = fluid.layers.reshape(x=image, shape=[-1, 1, 28, 28]) + yb = fluid.layers.reshape(y, [-1, y_dim, 1, 1]) + x = conv_cond_concat(image, yb) + + h0 = conv(x, c_dim + y_dim, act="leaky_relu") + h0 = conv_cond_concat(h0, yb) + h1 = bn(conv(h0, df_dim + y_dim), act="leaky_relu") + h1 = fluid.layers.flatten(h1, axis=1) + + h1 = fluid.layers.concat([h1, y], 1) + + h2 = bn(fc(h1, dfc_dim), act='leaky_relu') + h2 = fluid.layers.concat([h2, y], 1) + + h3 = fc(h2, 1) + return h3 + + +def G_cond(z, y): + s_h, s_w = output_height, output_width + s_h2, s_h4 = int(s_h / 2), int(s_h / 4) + s_w2, s_w4 = int(s_w / 2), int(s_w / 4) + + yb = fluid.layers.reshape(y, [-1, y_dim, 1, 1]) #NCHW + + z = fluid.layers.concat([z, y], 1) + h0 = bn(fc(z, gfc_dim / 2), act='relu') + h0 = fluid.layers.concat([h0, y], 1) + + h1 = bn(fc(h0, gf_dim * 2 * s_h4 * s_w4), act='relu') + h1 = fluid.layers.reshape(h1, [-1, gf_dim * 2, s_h4, s_w4]) + + h1 = conv_cond_concat(h1, yb) + h2 = bn(deconv(h1, gf_dim * 2, output_size=[s_h2, s_w2]), act='relu') + h2 = conv_cond_concat(h2, yb) + h3 = deconv(h2, 1, output_size=[s_h, s_w], act='tanh') + return fluid.layers.reshape(h3, shape=[-1, s_h * s_w]) + + +def D(x): + x = fluid.layers.reshape(x=x, shape=[-1, 1, 28, 28]) + x = conv(x, df_dim, act='leaky_relu') + x = bn(conv(x, df_dim * 2), act='leaky_relu') + x = bn(fc(x, dfc_dim), act='leaky_relu') + x = fc(x, 1, act=None) + return x + + +def G(x): + x = bn(fc(x, gfc_dim)) + x = bn(fc(x, gf_dim * 2 * img_dim / 4 * img_dim / 4)) + x = fluid.layers.reshape(x, [-1, gf_dim * 2, img_dim / 4, img_dim / 4]) + x = deconv(x, gf_dim * 2, act='relu', output_size=[14, 14]) + x = deconv(x, 1, filter_size=5, padding=2, act='tanh', output_size=[28, 28]) + x = fluid.layers.reshape(x, shape=[-1, 28 * 28]) + return x diff --git a/fluid/gan/c_gan/utility.py b/fluid/gan/c_gan/utility.py new file mode 100644 index 0000000000000000000000000000000000000000..b9cd4711b555a9947634f5c0d205ebff8cc77b8e --- /dev/null +++ b/fluid/gan/c_gan/utility.py @@ -0,0 +1,79 @@ +import math +import distutils.util +import numpy as np +import inspect +import matplotlib +matplotlib.use('agg') +import matplotlib.pyplot as plt +import matplotlib.gridspec as gridspec + +img_dim = 28 + + +def get_parent_function_name(): + return 
inspect.stack()[2][3] + '.' + inspect.stack()[1][3] + '.' + str( + inspect.stack()[2][2]) + '.' + + +def plot(gen_data): + pad_dim = 1 + paded = pad_dim + img_dim + gen_data = gen_data.reshape(gen_data.shape[0], img_dim, img_dim) + n = int(math.ceil(math.sqrt(gen_data.shape[0]))) + gen_data = (np.pad( + gen_data, [[0, n * n - gen_data.shape[0]], [pad_dim, 0], [pad_dim, 0]], + 'constant').reshape((n, n, paded, paded)).transpose((0, 2, 1, 3)) + .reshape((n * paded, n * paded))) + fig = plt.figure(figsize=(8, 8)) + plt.axis('off') + plt.imshow(gen_data, cmap='Greys_r', vmin=-1, vmax=1) + return fig + + +def check(a): + a = np.sort(np.array(a).flatten()) + return [ + np.average(a), np.min(a), np.max(a), a[int(len(a) * 0.25)], + a[int(len(a) * 0.75)] + ] + + +def print_arguments(args): + """Print argparse's arguments. + + Usage: + + .. code-block:: python + + parser = argparse.ArgumentParser() + parser.add_argument("name", default="Jonh", type=str, help="User name.") + args = parser.parse_args() + print_arguments(args) + + :param args: Input argparse.Namespace for printing. + :type args: argparse.Namespace + """ + print("----------- Configuration Arguments -----------") + for arg, value in sorted(vars(args).iteritems()): + print("%s: %s" % (arg, value)) + print("------------------------------------------------") + + +def add_arguments(argname, type, default, help, argparser, **kwargs): + """Add argparse's argument. + + Usage: + + .. code-block:: python + + parser = argparse.ArgumentParser() + add_argument("name", str, "Jonh", "User name.", parser) + args = parser.parse_args() + """ + type = distutils.util.strtobool if type == bool else type + argparser.add_argument( + "--" + argname, + default=default, + type=type, + help=help + ' Default: %(default)s.', + **kwargs) diff --git a/fluid/cycle_gan/README.md b/fluid/gan/cycle_gan/README.md similarity index 100% rename from fluid/cycle_gan/README.md rename to fluid/gan/cycle_gan/README.md diff --git a/fluid/cycle_gan/data_reader.py b/fluid/gan/cycle_gan/data_reader.py similarity index 100% rename from fluid/cycle_gan/data_reader.py rename to fluid/gan/cycle_gan/data_reader.py diff --git a/fluid/cycle_gan/images/A2B.jpg b/fluid/gan/cycle_gan/images/A2B.jpg similarity index 100% rename from fluid/cycle_gan/images/A2B.jpg rename to fluid/gan/cycle_gan/images/A2B.jpg diff --git a/fluid/cycle_gan/images/B2A.jpg b/fluid/gan/cycle_gan/images/B2A.jpg similarity index 100% rename from fluid/cycle_gan/images/B2A.jpg rename to fluid/gan/cycle_gan/images/B2A.jpg diff --git a/fluid/cycle_gan/images/cycleGAN_loss.png b/fluid/gan/cycle_gan/images/cycleGAN_loss.png similarity index 100% rename from fluid/cycle_gan/images/cycleGAN_loss.png rename to fluid/gan/cycle_gan/images/cycleGAN_loss.png diff --git a/fluid/cycle_gan/infer.py b/fluid/gan/cycle_gan/infer.py similarity index 100% rename from fluid/cycle_gan/infer.py rename to fluid/gan/cycle_gan/infer.py diff --git a/fluid/cycle_gan/layers.py b/fluid/gan/cycle_gan/layers.py similarity index 100% rename from fluid/cycle_gan/layers.py rename to fluid/gan/cycle_gan/layers.py diff --git a/fluid/cycle_gan/model.py b/fluid/gan/cycle_gan/model.py similarity index 100% rename from fluid/cycle_gan/model.py rename to fluid/gan/cycle_gan/model.py diff --git a/fluid/cycle_gan/train.py b/fluid/gan/cycle_gan/train.py similarity index 100% rename from fluid/cycle_gan/train.py rename to fluid/gan/cycle_gan/train.py diff --git a/fluid/cycle_gan/trainer.py b/fluid/gan/cycle_gan/trainer.py similarity index 100% rename from 
fluid/cycle_gan/trainer.py rename to fluid/gan/cycle_gan/trainer.py diff --git a/fluid/cycle_gan/utility.py b/fluid/gan/cycle_gan/utility.py similarity index 100% rename from fluid/cycle_gan/utility.py rename to fluid/gan/cycle_gan/utility.py
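Returning to `network.py` earlier in this diff: `conv_cond_concat` injects the class label into each stage of the conditional networks by broadcasting it across the spatial dimensions and concatenating it to the feature maps as extra channels. Below is a small NumPy sketch of the shapes involved (an editorial illustration; the real layer uses `fluid.layers.fill_constant_batch_size_like` and `fluid.layers.concat`):

```python
import numpy as np

def conv_cond_concat_np(x, y):
    """NumPy analogue of conv_cond_concat in network.py: x is NCHW, y is [N, y_dim, 1, 1]."""
    ones = np.ones((x.shape[0], y.shape[1], x.shape[2], x.shape[3]), dtype=x.dtype)
    return np.concatenate([x, ones * y], axis=1)  # label value broadcast over H and W

x = np.zeros((4, 64, 14, 14), dtype='float32')         # a batch of feature maps
y = np.arange(4, dtype='float32').reshape(4, 1, 1, 1)  # one label per sample (y_dim = 1)
print(conv_cond_concat_np(x, y).shape)  # (4, 65, 14, 14): one extra channel per label dim
```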