# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
import os
import six
import argparse
import functools
import matplotlib
import numpy as np
import paddle
import time
import paddle.fluid as fluid
from utility import get_parent_function_name, plot, check, add_arguments, print_arguments
from network import G_cond, D_cond
matplotlib.use('agg')
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

# Width of the uniform noise vector fed to the generator.
NOISE_SIZE = 100
# Learning rate of the single Adam optimizer used for both the D and G updates.
LEARNING_RATE = 2e-4

parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size', int,  121,        "Minibatch size.")
add_arg('epoch',      int,  20,         "The number of epochs to be trained.")
add_arg('output',     str,  "./output", "The directory the model and the test result to be saved to.")
add_arg('use_gpu',    bool, True,       "Whether to use GPU to train.")
add_arg('run_ce',     bool, False,      "Whether to run for model ce.")
# yapf: enable


def loss(x, label):
    """Return the mean sigmoid cross-entropy between logits and targets.

    Args:
        x: logits Variable (in this script, a discriminator output).
        label: target Variable matching ``x`` element-wise; this script feeds
            1.0 for "real" and 0.0 for "fake".

    Returns:
        A Variable holding the mean of the element-wise losses.
    """
    return fluid.layers.mean(
        fluid.layers.sigmoid_cross_entropy_with_logits(x=x, label=label))


def _sample_noise(batch_size):
    """Draw a float32 ``[batch_size, NOISE_SIZE]`` batch from U(-1, 1)."""
    return np.random.uniform(
        low=-1.0, high=1.0,
        size=[batch_size, NOISE_SIZE]).astype('float32')


def _build_train_reader(args):
    """Return the batched MNIST training reader.

    Samples are shuffled except in CE mode (``args.run_ce``), which keeps the
    original order, presumably so CE runs stay repeatable.
    """
    mnist_train = paddle.dataset.mnist.train()
    if not args.run_ce:
        mnist_train = fluid.io.shuffle(mnist_train, buf_size=60000)
    return fluid.io.batch(mnist_train, batch_size=args.batch_size)


def train(args):
    """Train the conditional GAN on MNIST.

    Two programs are built:
      * ``d_program`` scores (image, condition) pairs with ``D_cond`` and is
        optimised on its loss against the fed ``label``.
      * ``dg_program`` runs ``G_cond`` followed by ``D_cond``. It is optimised
        towards a target of 1.0, and only the generator's parameters are
        updated.

    Each full batch gets one D update on generated images (label 0) and one
    on real images (label 1), followed by ``NUM_TRAIN_TIMES_OF_DG`` G updates.
    Batches shorter than ``args.batch_size`` are skipped.

    Outside CE mode, every 10th batch a sample plot is saved to
    ``args.output``. In CE mode, mean losses and the mean time per epoch are
    printed as ``kpis`` lines instead.

    Args:
        args: parsed flags providing ``batch_size``, ``epoch``, ``output``,
            ``use_gpu`` and ``run_ce``.
    """
    if args.run_ce:
        # Fixed seeds for continuous-evaluation runs.
        np.random.seed(10)
        fluid.default_startup_program().random_seed = 90

    d_program = fluid.Program()
    dg_program = fluid.Program()

    # Discriminator program: D(img | conditions) against a fed label.
    with fluid.program_guard(d_program):
        conditions = fluid.data(
            name='conditions', shape=[None, 1], dtype='float32')
        img = fluid.data(name='img', shape=[None, 784], dtype='float32')
        label = fluid.data(name='label', shape=[None, 1], dtype='float32')
        d_logit = D_cond(img, conditions)
        d_loss = loss(d_logit, label)

    # Generator-then-discriminator program used to train G.
    with fluid.program_guard(dg_program):
        conditions = fluid.data(
            name='conditions', shape=[None, 1], dtype='float32')
        noise = fluid.data(
            name='noise', shape=[None, NOISE_SIZE], dtype='float32')
        g_img = G_cond(z=noise, y=conditions)

        # Clone before D is appended: these copies contain only the
        # generator. They are used for sampling and for listing G's
        # parameters.
        g_program = dg_program.clone()
        g_program_test = dg_program.clone(for_test=True)

        dg_logit = D_cond(g_img, conditions)
        dg_logit_shape = fluid.layers.shape(dg_logit)
        # G is trained to make D output "real" (1.0) on generated samples.
        dg_loss = loss(
            dg_logit,
            fluid.layers.fill_constant(
                dtype='float32', shape=[dg_logit_shape[0], 1], value=1.0))

    opt = fluid.optimizer.Adam(learning_rate=LEARNING_RATE)
    opt.minimize(loss=d_loss)
    # Limit the DG update to generator parameters so that D is not trained
    # towards the "fakes are real" target.
    parameters = [p.name for p in g_program.global_block().all_parameters()]
    opt.minimize(loss=dg_loss, parameter_list=parameters)

    place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())

    train_reader = _build_train_reader(args)

    NUM_TRAIN_TIMES_OF_DG = 2
    # Fixed noise so the periodically saved samples are comparable over time.
    const_n = _sample_noise(args.batch_size)

    # Created once up front instead of being re-checked inside the batch loop.
    if not args.run_ce and not os.path.exists(args.output):
        os.makedirs(args.output)

    t_time = 0
    losses = [[], []]  # [D losses per batch, DG losses per G step]
    for pass_id in range(args.epoch):
        for batch_id, data in enumerate(train_reader()):
            # The noise batches are sized args.batch_size and are fed
            # alongside conditions_data, so the row counts must match.
            if len(data) != args.batch_size:
                continue
            noise_data = _sample_noise(args.batch_size)
            real_image = np.array([x[0] for x in data]).reshape(
                -1, 784).astype('float32')
            conditions_data = np.array([x[1] for x in data]).reshape(
                [-1, 1]).astype("float32")
            real_labels = np.ones(
                shape=[real_image.shape[0], 1], dtype='float32')
            fake_labels = np.zeros(
                shape=[real_image.shape[0], 1], dtype='float32')

            s_time = time.time()
            generated_image = exe.run(
                g_program,
                feed={'noise': noise_data,
                      'conditions': conditions_data},
                fetch_list=[g_img])[0]

            # One D step on generated samples (label 0) ...
            d_loss_1 = exe.run(d_program,
                               feed={
                                   'img': generated_image,
                                   'label': fake_labels,
                                   'conditions': conditions_data
                               },
                               fetch_list=[d_loss])[0][0]
            # ... and one on real samples (label 1).
            d_loss_2 = exe.run(d_program,
                               feed={
                                   'img': real_image,
                                   'label': real_labels,
                                   'conditions': conditions_data
                               },
                               fetch_list=[d_loss])[0][0]
            d_loss_n = d_loss_1 + d_loss_2
            losses[0].append(d_loss_n)

            # Several G steps per D step, each with fresh noise.
            for _ in six.moves.xrange(NUM_TRAIN_TIMES_OF_DG):
                noise_data = _sample_noise(args.batch_size)
                dg_loss_n = exe.run(
                    dg_program,
                    feed={'noise': noise_data,
                          'conditions': conditions_data},
                    fetch_list=[dg_loss])[0][0]
                losses[1].append(dg_loss_n)
            batch_time = time.time() - s_time
            t_time += batch_time

            if batch_id % 10 == 0 and not args.run_ce:
                # Sample from the fixed noise using this batch's conditions.
                generated_images = exe.run(
                    g_program_test,
                    feed={'noise': const_n,
                          'conditions': conditions_data},
                    fetch_list=[g_img])[0]
                total_images = np.concatenate([real_image, generated_images])
                fig = plot(total_images)
                msg = ("Epoch ID={0}\n Batch ID={1}\n D-Loss={2}\n DG-Loss={3}\n gen={4}\n "
                       "Batch_time_cost={5:.2f}").format(
                           pass_id, batch_id, d_loss_n, dg_loss_n,
                           check(generated_images), batch_time)
                print(msg)
                plt.title(msg)
                plt.savefig(
                    '{}/{:04d}_{:04d}.png'.format(args.output, pass_id,
                                                  batch_id),
                    bbox_inches='tight')
                plt.close(fig)

    if args.run_ce:
        print("kpis,cgan_d_train_cost,{}".format(np.mean(losses[0])))
        print("kpis,cgan_g_train_cost,{}".format(np.mean(losses[1])))
        # Guard against --epoch 0, which previously raised ZeroDivisionError.
        print("kpis,cgan_duration,{}".format(t_time / max(args.epoch, 1)))

if __name__ == "__main__":
    # Script entry point: parse the flags, echo them, then run training.
    cli_args = parser.parse_args()
    print_arguments(cli_args)
    train(cli_args)