# fc_gan.py
# Committed by Yu Yang.
import errno
import math
import os

import matplotlib
import numpy

import paddle.v2 as paddle
import paddle.v2.fluid as fluid

matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

# Width of the noise vector fed to the generator's 'noise' input.
NOISE_SIZE = 100
# Number of passes made over the training reader.
NUM_PASS = 1000
# Batch size requested from the MNIST training reader (real images per batch).
NUM_REAL_IMGS_IN_BATCH = 121
# Number of generator (DG program) runs performed after each discriminator run.
NUM_TRAIN_TIMES_OF_DG = 3
# Learning rate passed to the Adam optimizer.
LEARNING_RATE = 2e-5


def D(x):
    """Build the discriminator network on top of ``x``.

    The network is a 200-unit ReLU fully connected layer followed by a
    single-unit fully connected layer with no activation.

    Parameters and biases are given the fixed names ``D.w1``/``D.b1`` and
    ``D.w2``/``D.b2`` (presumably so that every program calling ``D`` refers
    to the same weights -- confirm against the fluid naming rules).

    Args:
        x: input variable for the first fully connected layer.

    Returns:
        The output of the final layer (raw logits, no activation applied).
    """
    fc = fluid.layers.fc
    features = fc(input=x,
                  size=200,
                  act='relu',
                  param_attr='D.w1',
                  bias_attr='D.b1')
    return fc(input=features,
              size=1,
              act=None,
              param_attr='D.w2',
              bias_attr='D.b2')


def G(x):
    """Build the generator network on top of ``x``.

    The network is a 200-unit ReLU fully connected layer followed by a
    ``28 * 28``-unit fully connected layer with ``tanh`` activation.
    Parameters and biases use the fixed names ``G.w1``/``G.b1`` and
    ``G.w2``/``G.b2``.

    Args:
        x: input variable for the first fully connected layer.

    Returns:
        The output of the ``tanh`` layer, ``28 * 28`` values per sample.
    """
    fc = fluid.layers.fc
    features = fc(input=x,
                  size=200,
                  act='relu',
                  param_attr='G.w1',
                  bias_attr='G.b1')
    return fc(input=features,
              size=28 * 28,
              act='tanh',
              param_attr='G.w2',
              bias_attr='G.b2')


def plot(gen_data):
    """Draw a batch of 28x28 images on a near-square grid.

    Args:
        gen_data: numpy array whose first axis indexes samples and whose
            remaining elements total ``28 * 28`` per sample. The array is
            not modified.

    Returns:
        The matplotlib figure holding one axis per sample. The caller is
        responsible for saving and closing it.
    """
    # Use reshape rather than ndarray.resize so the caller's array keeps its
    # original shape; resize would change it in place.
    images = gen_data.reshape(gen_data.shape[0], 28, 28)

    # Smallest square grid that fits every sample; keep at least one cell so
    # an empty batch still produces a valid (blank) figure.
    n = max(1, int(math.ceil(math.sqrt(images.shape[0]))))
    fig = plt.figure(figsize=(n, n))
    gs = gridspec.GridSpec(n, n)
    gs.update(wspace=0.05, hspace=0.05)

    for i, sample in enumerate(images):
        # Address the figure and axes explicitly instead of relying on
        # pyplot's implicit "current figure/axes" state.
        ax = fig.add_subplot(gs[i])
        ax.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        ax.imshow(sample, cmap='Greys_r')

    return fig


def _sample_noise(batch_size):
    """Return a float32 (batch_size, NOISE_SIZE) array drawn from U[-1, 1)."""
    return numpy.random.uniform(
        low=-1.0, high=1.0,
        size=[batch_size, NOISE_SIZE]).astype('float32')


def main():
    """Train the fully connected GAN on MNIST and save sample images.

    Three programs are built:

    * ``d_program`` trains the discriminator on a batch made of real images
      (label 1) and generated images (label 0).
    * ``g_program`` is a clone of ``dg_program`` taken right after the
      generator is built, so it only produces images from noise.
    * ``dg_program`` feeds generated images through the discriminator with
      target label 1; its optimizer step is restricted to the parameters
      found in ``g_program``.

    After every pass, a grid of the last generated batch is written to
    ``out/<pass_id>.png``.

    Raises:
        OSError: if the output directory cannot be created for a reason
            other than already existing.
    """
    # Create the output directory; an already existing one is fine.
    try:
        os.makedirs("./out")
    except OSError as e:
        if e.errno != errno.EEXIST:
            raise

    startup_program = fluid.Program()
    d_program = fluid.Program()
    dg_program = fluid.Program()

    with fluid.program_guard(d_program, startup_program):
        img = fluid.layers.data(name='img', shape=[784], dtype='float32')
        d_loss = fluid.layers.sigmoid_cross_entropy_with_logits(
            x=D(img),
            label=fluid.layers.data(
                name='label', shape=[1], dtype='float32'))
        d_loss = fluid.layers.mean(x=d_loss)

    with fluid.program_guard(dg_program, startup_program):
        noise = fluid.layers.data(
            name='noise', shape=[NOISE_SIZE], dtype='float32')
        g_img = G(x=noise)
        # Clone before the discriminator is appended: g_program contains the
        # generator only.
        g_program = dg_program.clone()
        dg_loss = fluid.layers.sigmoid_cross_entropy_with_logits(
            x=D(g_img),
            label=fluid.layers.fill_constant_batch_size_like(
                input=noise, dtype='float32', shape=[-1, 1], value=1.0))
        dg_loss = fluid.layers.mean(x=dg_loss)

    opt = fluid.optimizer.Adam(learning_rate=LEARNING_RATE)

    opt.minimize(loss=d_loss, startup_program=startup_program)
    # Only the parameters present in g_program are updated by the DG loss.
    opt.minimize(
        loss=dg_loss,
        startup_program=startup_program,
        parameter_list=[
            p.name for p in g_program.global_block().all_parameters()
        ])
    exe = fluid.Executor(fluid.CPUPlace())
    exe.run(startup_program)

    train_reader = paddle.batch(
        paddle.reader.shuffle(
            paddle.dataset.mnist.train(), buf_size=60000),
        batch_size=NUM_REAL_IMGS_IN_BATCH)

    for pass_id in range(NUM_PASS):
        # Stays None when the reader yields no batch, so the plot is skipped.
        generated_img = None
        for batch_id, data in enumerate(train_reader()):
            # The last batch of a pass may be smaller than the nominal size.
            num_true = len(data)

            # Discriminator step: real images labelled 1, generated ones 0.
            generated_img = exe.run(g_program,
                                    feed={'noise': _sample_noise(num_true)},
                                    fetch_list=[g_img])[0]
            # A list comprehension (not map) so numpy receives a real
            # sequence on Python 3 as well as Python 2.
            real_data = numpy.array(
                [sample[0] for sample in data]).astype('float32')
            real_data = real_data.reshape(num_true, 784)
            total_data = numpy.concatenate([real_data, generated_img])
            total_label = numpy.concatenate([
                numpy.ones(
                    shape=[real_data.shape[0], 1], dtype='float32'),
                numpy.zeros(
                    shape=[generated_img.shape[0], 1], dtype='float32')
            ])
            d_loss_np = exe.run(d_program,
                                feed={'img': total_data,
                                      'label': total_label},
                                fetch_list=[d_loss])[0]

            # Generator steps: several updates per discriminator update, each
            # on a noise batch twice the size of the real batch.
            dg_loss_np = None
            for _ in range(NUM_TRAIN_TIMES_OF_DG):
                dg_loss_np = exe.run(
                    dg_program,
                    feed={'noise': _sample_noise(2 * num_true)},
                    fetch_list=[dg_loss])[0]
            print("Pass ID={0}, Batch ID={1}, D-Loss={2}, DG-Loss={3}".format(
                pass_id, batch_id, d_loss_np, dg_loss_np))

        if generated_img is None:
            continue
        # Save a grid of the last generated batch once per pass.
        fig = plot(generated_img)
        fig.savefig(
            'out/{0}.png'.format(str(pass_id).zfill(3)), bbox_inches='tight')
        plt.close(fig)


# Run training only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()