#Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from network.DCGAN_network import DCGAN_model
from util import utility

import sys
import six
import os
import numpy as np
import time
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
import paddle.fluid as fluid
import random


class GTrainer():
    """Generator-side training graph of the DCGAN.

    The constructor clones the default main program and, inside that clone,
    builds the generator, runs the discriminator on the generated output and
    attaches an Adam optimizer restricted to the parameters whose names
    start with ``"G"``.

    Attributes set by the constructor:
        program: clone of the default main program holding the training ops.
        fake: output variable of ``network_G`` (marked persistable).
        infer_program: ``for_test`` clone of ``program`` taken right after
            the generator is built, i.e. before the discriminator, loss and
            optimizer ops are appended.
        g_loss: mean sigmoid cross-entropy between the discriminator logits
            on ``fake`` and an all-ones label tensor (marked persistable).
    """

    def __init__(self, input, label, cfg):
        """Build the generator training program.

        Args:
            input: variable passed to ``network_G``; also used as the
                batch-size reference for the constant label tensor.
            label: not used by this constructor; kept so the signature
                stays parallel to ``DTrainer``.
            cfg: configuration object; ``batch_size`` and ``learning_rate``
                are read from it.
        """
        self.program = fluid.default_main_program().clone()
        with fluid.program_guard(self.program):
            model = DCGAN_model(cfg.batch_size)
            self.fake = model.network_G(input, name='G')
            self.fake.persistable = True
            # Snapshot the program while it only contains the generator so
            # the clone can be used to produce samples without loss/optimizer.
            self.infer_program = self.program.clone(for_test=True)

            d_fake = model.network_D(self.fake, name="D")
            # The generator is trained against labels of 1.0 for its own
            # samples.
            fake_labels = fluid.layers.fill_constant_batch_size_like(
                input, dtype='float32', shape=[-1, 1], value=1.0)
            self.g_loss = fluid.layers.reduce_mean(
                fluid.layers.sigmoid_cross_entropy_with_logits(
                    x=d_fake, label=fake_labels))
            self.g_loss.persistable = True

            # Only parameters named "G..." are handed to the optimizer, so
            # the discriminator weights are not updated by this program's
            # minimize step.
            g_param_names = [
                var.name for var in self.program.list_vars()
                if fluid.io.is_parameter(var) and var.name.startswith("G")
            ]
            optimizer = fluid.optimizer.Adam(
                learning_rate=cfg.learning_rate, beta1=0.5, name="net_G")
            optimizer.minimize(self.g_loss, parameter_list=g_param_names)


class DTrainer():
    """Discriminator-side training graph of the DCGAN.

    The constructor clones the default main program and, inside that clone,
    builds the discriminator on ``input`` and attaches an Adam optimizer
    restricted to the parameters whose names start with ``"D"``.

    Attributes set by the constructor:
        program: clone of the default main program holding the training ops.
        d_loss: mean sigmoid cross-entropy between the discriminator logits
            and ``labels`` (marked persistable).
    """

    def __init__(self, input, labels, cfg):
        """Build the discriminator training program.

        Args:
            input: variable passed to ``network_D``.
            labels: target variable for the sigmoid cross-entropy loss.
            cfg: configuration object; ``batch_size`` and ``learning_rate``
                are read from it.
        """
        self.program = fluid.default_main_program().clone()
        with fluid.program_guard(self.program):
            model = DCGAN_model(cfg.batch_size)
            d_logit = model.network_D(input, name="D")
            self.d_loss = fluid.layers.reduce_mean(
                fluid.layers.sigmoid_cross_entropy_with_logits(
                    x=d_logit, label=labels))
            self.d_loss.persistable = True

            # Restrict the update to parameters named "D...".
            d_param_names = [
                var.name for var in self.program.list_vars()
                if fluid.io.is_parameter(var) and var.name.startswith("D")
            ]

            optimizer = fluid.optimizer.Adam(
                learning_rate=cfg.learning_rate, beta1=0.5, name="net_D")
            optimizer.minimize(self.d_loss, parameter_list=d_param_names)


class DCGAN(object):
    """DCGAN training driver: builds the G/D programs and runs the loop."""

    def add_special_args(self, parser):
        """Register the DCGAN-specific command line options.

        Args:
            parser: object exposing an argparse-style ``add_argument``.

        Returns:
            The same ``parser`` object that was passed in.
        """
        parser.add_argument(
            '--noise_size', type=int, default=100, help="the noise dimension")
        parser.add_argument(
            '--enable_ce',
            action='store_true',
            help="if set, run the tasks with continuous evaluation logs")
        return parser

    def __init__(self, cfg=None, train_reader=None):
        """Store the configuration and the training data reader.

        Args:
            cfg: configuration object read by ``build_model``.
            train_reader: callable returning an iterable of batches; each
                batch is a sequence of samples whose first element is the
                image data.
        """
        self.cfg = cfg
        self.train_reader = train_reader

    def _sample_noise(self):
        """Draw one float32 batch of generator input, uniform in [-1, 1)."""
        return np.random.uniform(
            low=-1.0, high=1.0,
            size=[self.cfg.batch_size, self.cfg.noise_size]).astype('float32')

    def build_model(self):
        """Build both trainers and run the full training loop.

        For every full-size batch the discriminator program is run once on
        the real images (labels of ones) and once on generated images
        (labels of zeros); the generator program is then run
        ``cfg.num_generator_time`` times. Depending on ``cfg`` flags the
        method also saves sample images, writes checkpoints and prints
        continuous-evaluation KPI lines.
        """
        img = fluid.data(name='img', shape=[None, 784], dtype='float32')
        noise = fluid.data(
            name='noise', shape=[None, self.cfg.noise_size], dtype='float32')
        label = fluid.data(name='label', shape=[None, 1], dtype='float32')
        # used for continuous evaluation: fix every random source
        if self.cfg.enable_ce:
            fluid.default_startup_program().random_seed = 90
            random.seed(0)
            np.random.seed(0)

        g_trainer = GTrainer(noise, label, self.cfg)
        d_trainer = DTrainer(img, label, self.cfg)

        # prepare environment
        place = fluid.CUDAPlace(0) if self.cfg.use_gpu else fluid.CPUPlace()
        exe = fluid.Executor(place)
        exe.run(fluid.default_startup_program())

        # Fixed noise batch reused for every test image so that successive
        # snapshots are generated from the same input.
        const_n = self._sample_noise()

        if self.cfg.init_model:
            utility.init_checkpoints(self.cfg, exe, g_trainer, "net_G")
            utility.init_checkpoints(self.cfg, exe, d_trainer, "net_D")

        ### memory optim
        build_strategy = fluid.BuildStrategy()
        build_strategy.enable_inplace = True

        g_trainer_program = fluid.CompiledProgram(
            g_trainer.program).with_data_parallel(
                loss_name=g_trainer.g_loss.name, build_strategy=build_strategy)
        d_trainer_program = fluid.CompiledProgram(
            d_trainer.program).with_data_parallel(
                loss_name=d_trainer.d_loss.name, build_strategy=build_strategy)

        if self.cfg.run_test:
            image_path = os.path.join(self.cfg.output, 'test')
            if not os.path.exists(image_path):
                os.makedirs(image_path)

        t_time = 0
        # Pre-bind the per-batch results: they are read after the loops (KPI
        # output) and would otherwise be undefined when no full-size batch
        # is seen or when the generator is configured to run zero times.
        d_loss = None
        g_loss = None
        batch_time = None
        for epoch_id in range(self.cfg.epoch):
            for batch_id, data in enumerate(self.train_reader()):
                # Skip incomplete batches.
                if len(data) != self.cfg.batch_size:
                    continue

                noise_data = self._sample_noise()
                real_image = np.array(
                    [sample[0] for sample in data]).reshape(
                        [-1, 784]).astype('float32')
                real_label = np.ones(
                    shape=[real_image.shape[0], 1], dtype='float32')
                fake_label = np.zeros(
                    shape=[real_image.shape[0], 1], dtype='float32')
                s_time = time.time()

                generate_image = exe.run(g_trainer_program,
                                         feed={'noise': noise_data},
                                         fetch_list=[g_trainer.fake])

                d_real_loss = exe.run(
                    d_trainer_program,
                    feed={'img': real_image,
                          'label': real_label},
                    fetch_list=[d_trainer.d_loss])[0]
                d_fake_loss = exe.run(
                    d_trainer_program,
                    feed={'img': generate_image[0],
                          'label': fake_label},
                    fetch_list=[d_trainer.d_loss])[0]
                d_loss = d_real_loss + d_fake_loss

                for _ in range(self.cfg.num_generator_time):
                    noise_data = self._sample_noise()
                    g_loss = exe.run(g_trainer_program,
                                     feed={'noise': noise_data},
                                     fetch_list=[g_trainer.g_loss])[0]

                batch_time = time.time() - s_time

                if batch_id % self.cfg.print_freq == 0:
                    print(
                        'Epoch ID: {} Batch ID: {} D_loss: {} G_loss: {} Batch_time_cost: {}'.
                        format(epoch_id, batch_id, d_loss[0],
                               None if g_loss is None else g_loss[0],
                               batch_time))

                t_time += batch_time

                if self.cfg.run_test:
                    generate_const_image = exe.run(
                        g_trainer.infer_program,
                        feed={'noise': const_n},
                        fetch_list=[g_trainer.fake])[0]

                    generate_image_reshape = np.reshape(generate_const_image, (
                        self.cfg.batch_size, -1))
                    total_images = np.concatenate(
                        [real_image, generate_image_reshape])
                    fig = utility.plot(total_images)

                    plt.title('Epoch ID={}, Batch ID={}'.format(epoch_id,
                                                                batch_id))
                    img_name = '{:04d}_{:04d}.png'.format(epoch_id, batch_id)
                    plt.savefig(
                        os.path.join(image_path, img_name), bbox_inches='tight')
                    plt.close(fig)

            if self.cfg.save_checkpoints:
                utility.checkpoints(epoch_id, self.cfg, exe, g_trainer, "net_G")
                utility.checkpoints(epoch_id, self.cfg, exe, d_trainer, "net_D")
        # used for continuous evaluation
        if self.cfg.enable_ce:
            if d_loss is None or g_loss is None:
                # Nothing was trained (or the generator never ran): there is
                # no KPI value to report, so say so instead of crashing.
                sys.stderr.write(
                    "enable_ce is set but no D/G loss was computed; "
                    "skipping kpis output\n")
            else:
                device_num = fluid.core.get_cuda_device_count(
                ) if self.cfg.use_gpu else 1
                print("kpis\tdcgan_d_loss_card{}\t{}".format(device_num,
                                                             d_loss[0]))
                print("kpis\tdcgan_g_loss_card{}\t{}".format(device_num,
                                                             g_loss[0]))
                print("kpis\tdcgan_Batch_time_cost_card{}\t{}".format(
                    device_num, batch_time))