DCGAN.py 8.2 KB
Newer Older
L
lvmengsi 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35
#Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from network.DCGAN_network import DCGAN_model
from util import utility

import sys
import six
import os
import numpy as np
import time
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
import paddle.fluid as fluid


class GTrainer():
    """Build the generator's training program.

    The current default main program is cloned. Inside the clone the
    generator (built with ``name='G'``) is stacked under the discriminator
    (``name="D"``). The generator loss is the mean sigmoid cross-entropy of
    D's logits on the generated samples against an all-ones target. An Adam
    optimizer minimizes that loss over the parameters whose names start
    with ``"G"`` only.

    Args:
        input: noise input variable fed to the generator; also used as the
            batch-size reference for the all-ones target.
        label: not referenced by this constructor.
        cfg: configuration object; ``batch_size`` and ``learning_rate`` are
            read here.

    Instance attributes:
        program: the cloned training program (G + D + G's optimizer ops).
        fake: generator output variable (marked persistable).
        infer_program: ``for_test`` clone of ``program``. It is taken right
            after the generator is built, so it contains no D network, loss
            or optimizer ops.
        g_loss: scalar generator loss variable (marked persistable).
    """

    def __init__(self, input, label, cfg):
        self.program = fluid.default_main_program().clone()
        with fluid.program_guard(self.program):
            model = DCGAN_model(cfg.batch_size)
            self.fake = model.network_G(input, name='G')
            self.fake.persistable = True
            # Snapshot the generator-only graph before D/loss/optimizer ops
            # are appended; used for sampling images.
            self.infer_program = self.program.clone(for_test=True)

            d_fake = model.network_D(self.fake, name="D")
            # G is trained to make D classify its samples as real (target 1).
            target_ones = fluid.layers.fill_constant_batch_size_like(
                input, dtype='float32', shape=[-1, 1], value=1.0)
            self.g_loss = fluid.layers.reduce_mean(
                fluid.layers.sigmoid_cross_entropy_with_logits(
                    x=d_fake, label=target_ones))
            self.g_loss.persistable = True

            # Restrict the update to generator weights (names prefixed "G");
            # presumably D's parameters carry a "D" prefix and stay frozen
            # here -- confirm against DCGAN_network.
            g_params = [
                var.name for var in self.program.list_vars()
                if fluid.io.is_parameter(var) and var.name.startswith("G")
            ]
            optimizer = fluid.optimizer.Adam(
                learning_rate=cfg.learning_rate, beta1=0.5, name="net_G")
            optimizer.minimize(self.g_loss, parameter_list=g_params)


class DTrainer():
    """Build the discriminator's training program.

    The current default main program is cloned. Inside the clone the
    discriminator (``name="D"``) is applied to ``input``, and the loss is
    the mean sigmoid cross-entropy of its logits against ``labels``. An
    Adam optimizer minimizes that loss over the parameters whose names
    start with ``"D"`` only.

    Args:
        input: image variable fed to the discriminator (real or generated).
        labels: per-sample target variable for the cross-entropy. For
            example, ``DCGAN.build_model`` in this file feeds ones for real
            images and zeros for generated ones.
        cfg: configuration object; ``batch_size`` and ``learning_rate`` are
            read here.

    Instance attributes:
        program: the cloned training program (D + D's optimizer ops).
        d_loss: scalar discriminator loss variable (marked persistable).
    """

    def __init__(self, input, labels, cfg):
        self.program = fluid.default_main_program().clone()
        with fluid.program_guard(self.program):
            model = DCGAN_model(cfg.batch_size)
            d_logit = model.network_D(input, name="D")
            self.d_loss = fluid.layers.reduce_mean(
                fluid.layers.sigmoid_cross_entropy_with_logits(
                    x=d_logit, label=labels))
            self.d_loss.persistable = True

            # Restrict the update to discriminator weights (names prefixed "D").
            d_params = [
                var.name for var in self.program.list_vars()
                if fluid.io.is_parameter(var) and var.name.startswith("D")
            ]

            optimizer = fluid.optimizer.Adam(
                learning_rate=cfg.learning_rate, beta1=0.5, name="net_D")
            optimizer.minimize(self.d_loss, parameter_list=d_params)


class DCGAN(object):
    """Train a DCGAN on images flattened to 784 values.

    Images are reshaped to ``[-1, 784]`` before being fed to the
    discriminator (presumably 28x28 single-channel images such as MNIST --
    confirm against the reader). Each training step does the following:

    1. Run one discriminator update on a real batch (target 1).
    2. Run one discriminator update on a generated batch (target 0).
    3. Run ``cfg.num_generator_time`` generator updates.
    """

    def add_special_args(self, parser):
        """Register DCGAN-specific command-line options on ``parser``.

        Adds ``--noise_size`` (int, default 100), the dimension of the noise
        vector fed to the generator.

        Returns:
            The same ``parser`` object.
        """
        parser.add_argument(
            '--noise_size', type=int, default=100, help="the noise dimension")

        return parser

    def __init__(self, cfg=None, train_reader=None):
        """Store the run configuration and the training data reader.

        Args:
            cfg: parsed configuration; see ``build_model`` for fields read.
            train_reader: callable returning an iterable of batches. Each
                batch is a sequence of samples whose element 0 is an image.
        """
        self.cfg = cfg
        self.train_reader = train_reader

    def _sample_noise(self):
        """Return a ``[batch_size, noise_size]`` float32 array from U(-1, 1)."""
        return np.random.uniform(
            low=-1.0,
            high=1.0,
            size=[self.cfg.batch_size, self.cfg.noise_size]).astype('float32')

    def _save_test_image(self, exe, g_trainer, const_n, real_image,
                         image_path, epoch_id, batch_id):
        """Plot real images together with samples from ``const_n`` and save a PNG.

        The file is written to ``image_path`` as
        ``<epoch_id:04d>_<batch_id:04d>.png``.
        """
        generate_const_image = exe.run(
            g_trainer.infer_program,
            feed={'noise': const_n},
            fetch_list=[g_trainer.fake])[0]

        generate_image_reshape = np.reshape(generate_const_image,
                                            (self.cfg.batch_size, -1))
        total_images = np.concatenate([real_image, generate_image_reshape])
        fig = utility.plot(total_images)

        plt.title('Epoch ID={}, Batch ID={}'.format(epoch_id, batch_id))
        img_name = '{:04d}_{:04d}.png'.format(epoch_id, batch_id)
        plt.savefig(os.path.join(image_path, img_name), bbox_inches='tight')
        plt.close(fig)

    def build_model(self):
        """Build the G/D training programs and run the full training loop.

        Config fields read here are ``noise_size``, ``batch_size``,
        ``use_gpu``, ``init_model``, ``run_test``, ``output``, ``epoch``,
        ``num_generator_time``, ``print_freq`` and ``save_checkpoints``. The
        trainers also read ``learning_rate``, and the ``utility`` checkpoint
        helpers may read more.

        Raises:
            ValueError: if ``cfg.num_generator_time`` is less than 1. At
                least one generator step per batch is needed to report a
                G loss.
        """
        if self.cfg.num_generator_time < 1:
            raise ValueError(
                "num_generator_time must be >= 1, got {}".format(
                    self.cfg.num_generator_time))

        img = fluid.data(name='img', shape=[None, 784], dtype='float32')
        noise = fluid.data(
            name='noise', shape=[None, self.cfg.noise_size], dtype='float32')
        label = fluid.data(name='label', shape=[None, 1], dtype='float32')

        g_trainer = GTrainer(noise, label, self.cfg)
        d_trainer = DTrainer(img, label, self.cfg)

        # prepare environment
        place = fluid.CUDAPlace(0) if self.cfg.use_gpu else fluid.CPUPlace()
        exe = fluid.Executor(place)
        exe.run(fluid.default_startup_program())

        # Fixed noise, reused for every saved test image.
        const_n = self._sample_noise()

        if self.cfg.init_model:
            utility.init_checkpoints(self.cfg, exe, g_trainer, "net_G")
            utility.init_checkpoints(self.cfg, exe, d_trainer, "net_D")

        ### memory optim
        build_strategy = fluid.BuildStrategy()
        build_strategy.enable_inplace = True

        g_trainer_program = fluid.CompiledProgram(
            g_trainer.program).with_data_parallel(
                loss_name=g_trainer.g_loss.name, build_strategy=build_strategy)
        d_trainer_program = fluid.CompiledProgram(
            d_trainer.program).with_data_parallel(
                loss_name=d_trainer.d_loss.name, build_strategy=build_strategy)

        image_path = None
        if self.cfg.run_test:
            image_path = os.path.join(self.cfg.output, 'test')
            if not os.path.exists(image_path):
                os.makedirs(image_path)

        for epoch_id in range(self.cfg.epoch):
            for batch_id, data in enumerate(self.train_reader()):
                # Skip incomplete batches: noise is always drawn with
                # batch_size rows.
                if len(data) != self.cfg.batch_size:
                    continue

                noise_data = self._sample_noise()
                real_image = np.array([sample[0] for sample in data]).reshape(
                    [-1, 784]).astype('float32')
                real_label = np.ones(
                    shape=[real_image.shape[0], 1], dtype='float32')
                fake_label = np.zeros(
                    shape=[real_image.shape[0], 1], dtype='float32')
                s_time = time.time()

                # NOTE(review): this runs the full G *training* program, which
                # includes G's Adam update. G therefore takes one extra step
                # here besides the num_generator_time steps below. Confirm
                # this is intended.
                generate_image = exe.run(g_trainer_program,
                                         feed={'noise': noise_data},
                                         fetch_list=[g_trainer.fake])

                # D update: real images against 1, generated images against 0.
                d_real_loss = exe.run(
                    d_trainer_program,
                    feed={'img': real_image,
                          'label': real_label},
                    fetch_list=[d_trainer.d_loss])[0]
                d_fake_loss = exe.run(
                    d_trainer_program,
                    feed={'img': generate_image[0],
                          'label': fake_label},
                    fetch_list=[d_trainer.d_loss])[0]
                d_loss = d_real_loss + d_fake_loss

                # G updates, each with fresh noise; g_loss keeps the last one.
                for _ in six.moves.xrange(self.cfg.num_generator_time):
                    noise_data = self._sample_noise()
                    g_loss = exe.run(g_trainer_program,
                                     feed={'noise': noise_data},
                                     fetch_list=[g_trainer.g_loss])[0]

                batch_time = time.time() - s_time

                if batch_id % self.cfg.print_freq == 0:
                    print(
                        'Epoch ID: {} Batch ID: {} D_loss: {} G_loss: {} Batch_time_cost: {}'.
                        format(epoch_id, batch_id, d_loss[0], g_loss[0],
                               batch_time))

                if self.cfg.run_test:
                    self._save_test_image(exe, g_trainer, const_n, real_image,
                                          image_path, epoch_id, batch_id)

            if self.cfg.save_checkpoints:
                utility.checkpoints(epoch_id, self.cfg, exe, g_trainer, "net_G")
                utility.checkpoints(epoch_id, self.cfg, exe, d_trainer, "net_D")