gan_trainer.py 12.4 KB
Newer Older
1
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
X
xuwei06 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import random
import numpy
W
wangyang59 已提交
18
import cPickle
Y
Yu Yang 已提交
19
import sys, os
W
wangyang59 已提交
20
from PIL import Image
X
xuwei06 已提交
21 22 23 24

from paddle.trainer.config_parser import parse_config
from paddle.trainer.config_parser import logger
import py_paddle.swig_paddle as api
25 26
import matplotlib.pyplot as plt

Y
Yu Yang 已提交
27

28
def plot2DScatter(data, outputfile):
    '''
    Plot the data as a 2D scatter plot and save it to ``outputfile``.

    The per-column mean and standard deviation of ``data`` are logged before
    plotting.

    :param data: 2-D array with at least two columns; column 0 is used as x
                 and column 1 as y
    :param outputfile: path the figure is saved to
    '''
    x = data[:, 0]
    y = data[:, 1]

    logger.info("The mean vector is %s" % numpy.mean(data, 0))
    logger.info("The std vector is %s" % numpy.std(data, 0))

    # The former histogram2d/extent computation was never used; it only cost
    # time and raised on non-finite data, so it has been removed.
    plt.clf()
    plt.scatter(x, y)
    plt.savefig(outputfile, bbox_inches='tight')
X
xuwei06 已提交
44

Y
Yu Yang 已提交
45

X
xuwei06 已提交
46 47 48
def CHECK_EQ(a, b):
    '''
    Raise AssertionError unless ``a == b``.

    Unlike a bare ``assert`` statement, this check is not stripped when Python
    runs with ``-O``, so it always guards the code that relies on it.

    :raises AssertionError: with message ``"a=<a>, b=<b>"`` if the values differ
    '''
    if not (a == b):
        raise AssertionError("a=%s, b=%s" % (a, b))

Y
Yu Yang 已提交
49

X
xuwei06 已提交
50
def copy_shared_parameters(src, dst):
    '''
    Copy every parameter that ``src`` and ``dst`` share, matched by name, from
    ``src`` into ``dst``.

    Parameters of ``dst`` with no same-named parameter in ``src`` are left
    untouched. Matching buffers must have equal length.

    :param src: the source of the parameters
    :type src: GradientMachine
    :param dst: the destination of the parameters
    :type dst: GradientMachine
    '''
    by_name = {}
    for idx in range(src.getParameterSize()):
        param = src.getParameter(idx)
        by_name[param.getName()] = param

    for idx in range(dst.getParameterSize()):
        target = dst.getParameter(idx)
        source = by_name.get(target.getName())
        if source is not None:
            source_buf = source.getBuf(api.PARAMETER_VALUE)
            target_buf = target.getBuf(api.PARAMETER_VALUE)
            CHECK_EQ(len(source_buf), len(target_buf))
            target_buf.copyFrom(source_buf)
            # Tell the destination machine its value buffer changed.
            target.setValueUpdated()
Y
Yu Yang 已提交
71 72


73
def print_parameters(src):
    '''
    Print the name and current value of every parameter of ``src`` to stdout.

    :param src: the machine whose parameters are printed
    :type src: GradientMachine
    '''
    params = list(map(src.getParameter, range(src.getParameterSize())))

    print("***************")
    for param in params:
        print("Name is %s" % param.getName())
        values = param.getBuf(api.PARAMETER_VALUE).copyToNumpyArray()
        print("value is %s \n" % values)

X
xuwei06 已提交
82

W
wangyang59 已提交
83 84 85 86 87 88 89 90 91
def load_mnist_data(imageFile):
    '''
    Load an uncompressed MNIST image file (IDX3 format) and scale its pixels
    from [0, 255] to [-1, 1].

    The number of images and their height/width are read from the 16-byte
    IDX header rather than guessed from the file name. This means both the
    official training and test files load correctly, whatever they are named.

    :param imageFile: path to a ``*-images-idx3-ubyte`` file
    :return: float32 array of shape (num_images, rows * cols)
    :raises ValueError: if the header magic number is not 2051 (IDX3 unsigned
                        bytes) or the file holds fewer pixels than announced
    '''
    import struct

    with open(imageFile, "rb") as f:
        # Header: magic, image count, rows, cols as big-endian uint32.
        magic, n, rows, cols = struct.unpack(">IIII", f.read(16))
        if magic != 2051:
            raise ValueError("%s is not an IDX3 image file (magic=%d)" %
                             (imageFile, magic))
        num_pixels = n * rows * cols
        data = numpy.frombuffer(f.read(num_pixels), dtype=numpy.uint8)

    if data.size != num_pixels:
        raise ValueError("%s is truncated: expected %d pixels, got %d" %
                         (imageFile, num_pixels, data.size))

    data = data.reshape((n, rows * cols))
    data = data / 255.0 * 2.0 - 1.0
    return data.astype('float32')
W
wangyang59 已提交
98

Y
Yu Yang 已提交
99

W
wangyang59 已提交
100 101
def load_cifar_data(cifar_path):
    '''
    Load the five CIFAR-10 training batches and scale pixels from [0, 255]
    to [-1, 1].

    :param cifar_path: directory containing ``data_batch_1`` .. ``data_batch_5``
    :return: float32 array of shape (50000, 3072); row k holds the
             ``"data"`` row of the corresponding image in its batch file
    '''
    # NOTE(security): cPickle can execute arbitrary code while loading; only
    # point this at a trusted copy of the data set.
    batch_size = 10000
    num_batches = 5
    data = numpy.zeros((num_batches * batch_size, 32 * 32 * 3),
                       dtype="float32")
    for i in range(num_batches):
        batch_path = os.path.join(cifar_path, "data_batch_%d" % (i + 1))
        # `with` closes the file even if unpickling fails.
        with open(batch_path, 'rb') as fo:
            batch = cPickle.load(fo)
        data[i * batch_size:(i + 1) * batch_size, :] = batch["data"]

    data = data / 255.0 * 2.0 - 1.0
    return data

Y
Yu Yang 已提交
113

W
wangyang59 已提交
114 115 116 117 118
# synthesize 2-D uniform data
def load_uniform_data(num_samples=1000000):
    '''
    Draw ``num_samples`` points uniformly from the unit square.

    :param num_samples: number of 2-D points to generate (default 1,000,000,
                        the previously hard-coded size)
    :return: float32 array of shape (num_samples, 2)
    '''
    data = numpy.random.rand(num_samples, 2).astype('float32')
    return data

Y
Yu Yang 已提交
119

W
wangyang59 已提交
120
def merge(images, size):
    '''
    Tile a batch of flattened images into a single grid image.

    :param images: 2-D array with one flattened image per row and values
                   expected in [-1, 1]. Rows of length 28*28 are treated as
                   single-channel 28x28 images. Any other length is treated
                   as 32x32x3 laid out as three channel planes of 32*32
                   values, each plane row-major.
    :param size: (rows, cols) of the grid; image k goes to tile row
                 ``k // cols``, tile column ``k % cols``
    :return: uint8 array of shape (rows * h, cols * w, c)
    '''
    if images.shape[1] == 28 * 28:
        h, w, c = 28, 28, 1
    else:
        h, w, c = 32, 32, 3
    rows, cols = size[0], size[1]
    img = numpy.zeros((h * rows, w * cols, c))
    # If there are fewer images than tiles, the remaining tiles stay black
    # instead of raising IndexError.
    for idx in range(min(images.shape[0], rows * cols)):
        i = idx % cols  # tile column
        j = idx // cols  # tile row
        # order="F" + transpose(1, 0, 2) maps the flat row back to
        # (row, col, channel).
        tile = images[idx, :].reshape((h, w, c), order="F").transpose(1, 0, 2)
        img[j * h:(j + 1) * h, i * w:(i + 1) * w, :] = \
            (tile + 1.0) / 2.0 * 255.0
    # Saturate before the uint8 cast; out-of-range values would otherwise
    # wrap around (e.g. 256 -> 0).
    return numpy.clip(img, 0.0, 255.0).astype('uint8')

Y
Yu Yang 已提交
133

134
def save_images(images, path):
    '''
    Tile the first 64 images of ``images`` into an 8x8 grid and save it as an
    RGB image at ``path``.

    :param images: flattened images in the format accepted by merge()
    :param path: output file path; PIL picks the format from the extension
    '''
    tiled = merge(images, [8, 8])
    if tiled.shape[2] != 1:
        picture = Image.fromarray(tiled, mode="RGB")
    else:
        # Single-channel grid: drop the channel axis so PIL gets a 2-D
        # array, then convert to RGB before saving.
        picture = Image.fromarray(numpy.squeeze(tiled)).convert('RGB')
    picture.save(path)
Y
Yu Yang 已提交
141 142


W
wangyang59 已提交
143
def get_real_samples(batch_size, data_np):
    '''
    Return ``batch_size`` distinct rows of ``data_np`` chosen at random.

    Rows are drawn without replacement, so ``batch_size`` must not exceed
    ``data_np.shape[0]``.
    '''
    num_rows = data_np.shape[0]
    picked = numpy.random.choice(num_rows, batch_size, replace=False)
    return data_np[picked, :]


W
wangyang59 已提交
148 149 150
def get_noise(batch_size, noise_dim):
    '''Draw a (batch_size, noise_dim) float32 matrix of standard-normal noise.'''
    shape = (batch_size, noise_dim)
    samples = numpy.random.normal(size=shape)
    return samples.astype('float32')

Y
Yu Yang 已提交
151

W
wangyang59 已提交
152 153 154
def get_fake_samples(generator_machine, batch_size, noise):
    '''
    Run ``generator_machine`` forward on ``noise`` and return its first output
    slot as a numpy matrix.

    :param generator_machine: GradientMachine that maps noise to samples
    :param batch_size: unused here. NOTE(review): the batch size is presumably
                       implied by the number of rows in ``noise``.
    :param noise: generator input matrix, e.g. from get_noise()
    '''
    gen_in = api.Arguments.createArguments(1)
    noise_matrix = api.Matrix.createDenseFromNumpy(noise)
    gen_in.setSlotValue(0, noise_matrix)

    gen_out = api.Arguments.createArguments(0)
    generator_machine.forward(gen_in, gen_out, api.PASS_TEST)
    return gen_out.getSlotValue(0).copyToNumpyMat()

Y
Yu Yang 已提交
160

161 162 163 164 165 166
def get_training_loss(training_machine, inputs):
    '''
    Run ``training_machine`` forward on ``inputs`` with api.PASS_TEST and
    return the mean of its first output slot.

    Only ``forward`` is called here; this function does not train or update
    parameters.
    '''
    out_args = api.Arguments.createArguments(0)
    training_machine.forward(inputs, out_args, api.PASS_TEST)
    loss_mat = out_args.getSlotValue(0).copyToNumpyMat()
    return numpy.mean(loss_mat)

Y
Yu Yang 已提交
167

W
wangyang59 已提交
168 169
def prepare_discriminator_data_batch_pos(batch_size, data_np):
    '''
    Build a discriminator batch of ``batch_size`` real samples, all labelled 1.

    :param batch_size: number of rows sampled from ``data_np``
    :param data_np: real data, one sample per row
    :return: api.Arguments with slot 0 = sample matrix, slot 1 = int32 labels
    '''
    real_batch = get_real_samples(batch_size, data_np)
    real_labels = numpy.ones(batch_size, dtype='int32')

    batch = api.Arguments.createArguments(2)
    batch.setSlotValue(0, api.Matrix.createDenseFromNumpy(real_batch))
    batch.setSlotIds(1, api.IVector.createVectorFromNumpy(real_labels))
    return batch

Y
Yu Yang 已提交
176

W
wangyang59 已提交
177 178
def prepare_discriminator_data_batch_neg(generator_machine, batch_size, noise):
    '''
    Build a discriminator batch of generated samples, all labelled 0.

    :param generator_machine: GradientMachine that turns ``noise`` into samples
    :param batch_size: number of labels to create
    :param noise: generator input matrix
    :return: api.Arguments with slot 0 = generated samples, slot 1 = int32 labels
    '''
    generated = get_fake_samples(generator_machine, batch_size, noise)
    fake_labels = numpy.zeros(batch_size, dtype='int32')

    batch = api.Arguments.createArguments(2)
    batch.setSlotValue(0, api.Matrix.createDenseFromNumpy(generated))
    batch.setSlotIds(1, api.IVector.createVectorFromNumpy(fake_labels))
    return batch
X
xuwei06 已提交
184

Y
Yu Yang 已提交
185

W
wangyang59 已提交
186
def prepare_generator_data_batch(batch_size, noise):
    '''
    Build a generator-training batch: ``noise`` as input, every label 1.

    NOTE(review): labelling generated data as 1 presumably asks the
    discriminator part of the generator-training network to call fakes
    real (the usual GAN generator objective). Confirm against the config.

    :param batch_size: number of labels to create
    :param noise: generator input matrix
    :return: api.Arguments with slot 0 = noise matrix, slot 1 = int32 labels
    '''
    target_labels = numpy.ones(batch_size, dtype='int32')

    batch = api.Arguments.createArguments(2)
    batch.setSlotValue(0, api.Matrix.createDenseFromNumpy(noise))
    batch.setSlotIds(1, api.IVector.createVectorFromNumpy(target_labels))
    return batch


def find(iterable, cond):
    '''Return the first item of ``iterable`` for which ``cond`` is true, else None.'''
    return next((item for item in iterable if cond(item)), None)


def get_layer_size(model_conf, layer_name):
    '''
    Return the ``size`` of the first layer in ``model_conf.layers`` named
    ``layer_name``.

    :param model_conf: model config exposing an iterable ``layers`` whose
                       items have ``name`` and ``size`` attributes
    :param layer_name: name of the layer to look up
    :raises ValueError: if no layer has that name. An explicit raise is used
                        instead of ``assert``, which ``python -O`` strips and
                        which would then fail later with ``None.size``.
    '''
    for layer_conf in model_conf.layers:
        if layer_conf.name == layer_name:
            return layer_conf.size
    raise ValueError("Cannot find '%s' layer" % layer_name)


def _parse_args():
    '''
    Parse the trainer's command-line options.

    Validation uses argparse ``required``/``choices`` instead of ``assert``.
    Asserts are stripped under ``python -O``, which would let e.g. a missing
    ``--data_source`` through as ``None``.

    :return: argparse.Namespace with data_source, use_gpu, gpu_id, num_passes
    '''
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-d",
        "--data_source",
        required=True,
        choices=["mnist", "cifar", "uniform"],
        help="mnist or cifar or uniform")
    parser.add_argument(
        "--use_gpu",
        default="1",
        choices=["0", "1"],
        help="1 means use gpu for training")
    parser.add_argument("--gpu_id", default="0", help="the gpu_id parameter")
    parser.add_argument(
        "--num_passes",
        type=int,
        default=100,
        help="number of training passes (default: 100)")
    return parser.parse_args()


def _load_training_data(data_source):
    '''
    Load the real training data selected by ``data_source``.

    :param data_source: "mnist", "cifar" or "uniform"
    :return: numpy array with one sample per row
    '''
    if data_source == "mnist":
        return load_mnist_data("./data/mnist_data/train-images-idx3-ubyte")
    if data_source == "cifar":
        return load_cifar_data("./data/cifar-10-batches-py/")
    return load_uniform_data()


def main():
    '''
    Train a GAN on MNIST, CIFAR-10 or synthetic 2-D uniform data.

    Three GradientMachines are built from one config file:
      * one for discriminator training,
      * one for generator training,
      * a generator-only machine used to produce fake samples.
    Shared parameters are copied between them after every update.

    In each iteration only one network is trained: the one with the larger
    loss, unless it has already been trained ``max_strike`` times in a row.
    After every pass a batch of generated samples is saved under
    ``./<data_source>_samples/``.
    '''
    args = _parse_args()
    data_source = args.data_source
    use_gpu = args.use_gpu

    for out_dir in ("./%s_samples/" % data_source,
                    "./%s_params/" % data_source):
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)

    api.initPaddle('--use_gpu=' + use_gpu, '--dot_period=10',
                   '--log_period=100', '--gpu_id=' + args.gpu_id,
                   '--save_dir=' + "./%s_params/" % data_source)

    if data_source == "uniform":
        conf = "gan_conf.py"
        num_iter = 10000
    else:
        conf = "gan_conf_image.py"
        num_iter = 1000

    gen_conf = parse_config(conf, "mode=generator_training,data=" + data_source)
    dis_conf = parse_config(conf,
                            "mode=discriminator_training,data=" + data_source)
    generator_conf = parse_config(conf, "mode=generator,data=" + data_source)
    batch_size = dis_conf.opt_config.batch_size
    noise_dim = get_layer_size(gen_conf.model_config, "noise")

    data_np = _load_training_data(data_source)

    # this creates a gradient machine for discriminator
    dis_training_machine = api.GradientMachine.createFromConfigProto(
        dis_conf.model_config)
    # this creates a gradient machine for generator
    gen_training_machine = api.GradientMachine.createFromConfigProto(
        gen_conf.model_config)

    # generator_machine is used to generate data only, which is used for
    # training discriminator
    logger.info(str(generator_conf.model_config))
    generator_machine = api.GradientMachine.createFromConfigProto(
        generator_conf.model_config)

    dis_trainer = api.Trainer.create(dis_conf, dis_training_machine)
    gen_trainer = api.Trainer.create(gen_conf, gen_training_machine)

    dis_trainer.startTrain()
    gen_trainer.startTrain()

    # Sync parameters between networks (GradientMachine) at the beginning
    copy_shared_parameters(gen_training_machine, dis_training_machine)
    copy_shared_parameters(gen_training_machine, generator_machine)

    # Neither the discriminator nor the generator may be trained more than
    # max_strike times in a row.
    curr_train = "dis"
    curr_strike = 0
    max_strike = 5

    for train_pass in xrange(args.num_passes):
        dis_trainer.startTrainPass()
        gen_trainer.startTrainPass()
        for i in xrange(num_iter):
            # Forward-only evaluation of both losses; the negative
            # discriminator batch and the generator batch share `noise`.
            noise = get_noise(batch_size, noise_dim)
            data_batch_dis_pos = prepare_discriminator_data_batch_pos(
                batch_size, data_np)
            dis_loss_pos = get_training_loss(dis_training_machine,
                                             data_batch_dis_pos)

            data_batch_dis_neg = prepare_discriminator_data_batch_neg(
                generator_machine, batch_size, noise)
            dis_loss_neg = get_training_loss(dis_training_machine,
                                             data_batch_dis_neg)

            dis_loss = (dis_loss_pos + dis_loss_neg) / 2.0

            data_batch_gen = prepare_generator_data_batch(batch_size, noise)
            gen_loss = get_training_loss(gen_training_machine, data_batch_gen)

            if i % 100 == 0:
                print("d_pos_loss is %s     d_neg_loss is %s" % (dis_loss_pos,
                                                                 dis_loss_neg))
                print("d_loss is %s    g_loss is %s" % (dis_loss, gen_loss))

            # Train the network with the larger loss, unless it has hit its
            # consecutive-training limit (then the other one is trained).
            dis_maxed_out = curr_train == "dis" and curr_strike == max_strike
            gen_maxed_out = curr_train == "gen" and curr_strike == max_strike
            train_dis = (not dis_maxed_out) and (gen_maxed_out or
                                                 dis_loss > gen_loss)
            next_train = "dis" if train_dis else "gen"
            if next_train == curr_train:
                curr_strike += 1
            else:
                curr_train = next_train
                curr_strike = 1

            if train_dis:
                dis_trainer.trainOneDataBatch(batch_size, data_batch_dis_neg)
                dis_trainer.trainOneDataBatch(batch_size, data_batch_dis_pos)
                copy_shared_parameters(dis_training_machine,
                                       gen_training_machine)
            else:
                gen_trainer.trainOneDataBatch(batch_size, data_batch_gen)
                # TODO: add API for paddle to allow true parameter sharing
                # between different GradientMachines so that we do not need
                # to copy shared parameters.
                copy_shared_parameters(gen_training_machine,
                                       dis_training_machine)
                copy_shared_parameters(gen_training_machine, generator_machine)

        dis_trainer.finishTrainPass()
        gen_trainer.finishTrainPass()
        # At the end of each pass, save the generated samples/images
        # (generated from the last iteration's noise).
        fake_samples = get_fake_samples(generator_machine, batch_size, noise)
        sample_path = "./%s_samples/train_pass%s.png" % (data_source,
                                                         train_pass)
        if data_source == "uniform":
            plot2DScatter(fake_samples, sample_path)
        else:
            save_images(fake_samples, sample_path)
    dis_trainer.finishTrain()
    gen_trainer.finishTrain()

Y
Yu Yang 已提交
347

X
xuwei06 已提交
348 349
if __name__ == "__main__":
    # Train only when run as a script, not when imported as a module.
    main()