Pix2pix.py 15.5 KB
Newer Older
Z
zhumanyu 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from network.Pix2pix_network import Pix2pix_model
from util import utility
20
from util import timer
Z
zhumanyu 已提交
21
import paddle.fluid as fluid
H
hysunflower 已提交
22
from paddle.fluid import profiler
Z
zhumanyu 已提交
23 24
import sys
import time
C
ceci3 已提交
25
import numpy as np
Z
zhumanyu 已提交
26 27 28 29 30 31 32 33 34 35 36


class GTrainer():
    """Assembles the generator-side training program of Pix2pix.

    Attributes set by the constructor:
        program: clone of the default main program that receives all ops
            created here.
        fake_B: output of ``network_G`` applied to ``input_A``.
        infer_program: clone of ``program`` taken right after the generator
            is built, i.e. before discriminator, loss and optimizer ops exist.
        pred: discriminator output for the (input_A, fake_B) pair; reshaped
            to 2-D in "vanilla" mode.
        g_loss_gan, g_loss_L1, g_loss: adversarial term, weighted L1 term and
            their sum.
        param: names of the parameters whose name starts with "generator".
    """

    def __init__(self, input_A, input_B, cfg, step_per_epoch):
        """Build the generator graph, its losses and its Adam optimizer.

        Args:
            input_A: generator input variable.
            input_B: variable compared against ``fake_B`` in the L1 term
                (presumably the ground-truth image -- confirm with caller).
            cfg: config object; ``gan_mode``, ``lambda_L1``,
                ``learning_rate`` and ``epoch`` are read here, and it is also
                forwarded to the network builders.
            step_per_epoch: number of steps per epoch, used to turn epoch
                indices into step boundaries for the learning-rate decay.

        Raises:
            NotImplementedError: if ``cfg.gan_mode`` is neither "lsgan" nor
                "vanilla".
        """
        self.program = fluid.default_main_program().clone()
        with fluid.program_guard(self.program):
            net = Pix2pix_model()
            self.fake_B = net.network_G(input_A, "generator", cfg=cfg)
            # Snapshot for inference: only the generator ops exist so far.
            self.infer_program = self.program.clone()
            fake_pair = fluid.layers.concat([input_A, self.fake_B], 1)
            self.pred = net.network_D(fake_pair, "discriminator", cfg)
            # Runtime batch size, so the label tensor follows the fed batch.
            batch = fluid.layers.shape(self.pred)[0]

            if cfg.gan_mode == "lsgan":
                real_label = fluid.layers.fill_constant(
                    shape=[batch] + list(self.pred.shape[1:]),
                    value=1,
                    dtype='float32')
                residual = fluid.layers.elementwise_sub(
                    x=self.pred, y=real_label)
                squared = fluid.layers.square(residual)
                self.g_loss_gan = fluid.layers.reduce_mean(squared)
            elif cfg.gan_mode == "vanilla":
                shape_4d = self.pred.shape
                flat_dim = shape_4d[1] * shape_4d[2] * shape_4d[3]
                self.pred = fluid.layers.reshape(
                    self.pred, [-1, flat_dim], inplace=True)
                real_label = fluid.layers.fill_constant(
                    shape=[batch] + list(self.pred.shape[1:]),
                    value=1,
                    dtype='float32')
                bce = fluid.layers.sigmoid_cross_entropy_with_logits(
                    x=self.pred, label=real_label)
                self.g_loss_gan = fluid.layers.mean(bce)
            else:
                raise NotImplementedError("gan_mode {} is not support!".format(
                    cfg.gan_mode))

            l1_diff = fluid.layers.elementwise_sub(x=input_B, y=self.fake_B)
            l1_mean = fluid.layers.reduce_mean(fluid.layers.abs(l1_diff))
            self.g_loss_L1 = l1_mean * cfg.lambda_L1
            self.g_loss = fluid.layers.elementwise_add(self.g_loss_L1,
                                                       self.g_loss_gan)

            base_lr = cfg.learning_rate
            # Only the generator's own parameters are handed to the optimizer.
            generator_params = [
                candidate.name for candidate in self.program.list_vars()
                if fluid.io.is_parameter(candidate) and
                candidate.name.startswith("generator")
            ]
            self.param = generator_params

            if cfg.epoch > 100:
                # Constant rate up to step 99 * step_per_epoch, then one
                # smaller value per following epoch.
                decay_boundaries = [99 * step_per_epoch] + [
                    epoch_idx * step_per_epoch
                    for epoch_idx in range(100, cfg.epoch - 1)
                ]
                decay_values = [base_lr] + [
                    base_lr * (1.0 - (epoch_idx - 99.0) / 101.0)
                    for epoch_idx in range(100, cfg.epoch)
                ]
                learning_rate = fluid.layers.piecewise_decay(
                    boundaries=decay_boundaries, values=decay_values)
            else:
                learning_rate = base_lr

            optimizer = fluid.optimizer.Adam(
                learning_rate=learning_rate,
                beta1=0.5,
                beta2=0.999,
                name="net_G")
            optimizer.minimize(self.g_loss, parameter_list=generator_params)


class DTrainer():
    """Assembles the discriminator-side training program of Pix2pix.

    Attributes set by the constructor:
        program: clone of the default main program that receives all ops
            created here.
        real_AB, fake_AB: ``input_A`` concatenated on axis 1 with ``input_B``
            and with ``fake_B`` respectively.
        pred_real, pred_fake: discriminator outputs for the two pairs;
            reshaped to 2-D in "vanilla" mode.
        d_loss_real, d_loss_fake, d_loss: per-branch losses and half their
            sum.
        param: names of the parameters whose name starts with
            "discriminator".
    """

    def __init__(self, input_A, input_B, fake_B, cfg, step_per_epoch):
        """Build the discriminator graph, its losses and its Adam optimizer.

        Args:
            input_A: first half of both pairs shown to the discriminator.
            input_B: second half of the "real" pair.
            fake_B: second half of the "fake" pair (fed from outside; the
                generator is not rebuilt in this program).
            cfg: config object; ``gan_mode``, ``learning_rate`` and ``epoch``
                are read here, and it is also forwarded to ``network_D``.
            step_per_epoch: number of steps per epoch, used to turn epoch
                indices into step boundaries for the learning-rate decay.

        Raises:
            NotImplementedError: if ``cfg.gan_mode`` is neither "lsgan" nor
                "vanilla".
        """
        self.program = fluid.default_main_program().clone()
        base_lr = cfg.learning_rate
        with fluid.program_guard(self.program):
            net = Pix2pix_model()
            self.real_AB = fluid.layers.concat([input_A, input_B], 1)
            self.fake_AB = fluid.layers.concat([input_A, fake_B], 1)
            # Both calls use the same "discriminator" name prefix.
            self.pred_real = net.network_D(
                self.real_AB, "discriminator", cfg=cfg)
            self.pred_fake = net.network_D(
                self.fake_AB, "discriminator", cfg=cfg)
            # Runtime batch size, so label tensors follow the fed batch.
            batch = fluid.layers.shape(input_A)[0]

            if cfg.gan_mode == "lsgan":
                real_label = fluid.layers.fill_constant(
                    shape=[batch] + list(self.pred_real.shape[1:]),
                    value=1,
                    dtype='float32')
                real_residual = fluid.layers.elementwise_sub(
                    x=self.pred_real, y=real_label)
                self.d_loss_real = fluid.layers.reduce_mean(
                    fluid.layers.square(real_residual))
                fake_squared = fluid.layers.square(x=self.pred_fake)
                self.d_loss_fake = fluid.layers.reduce_mean(fake_squared)
            elif cfg.gan_mode == "vanilla":
                shape_4d = self.pred_real.shape
                self.pred_real = fluid.layers.reshape(
                    self.pred_real,
                    [-1, shape_4d[1] * shape_4d[2] * shape_4d[3]],
                    inplace=True)
                self.pred_fake = fluid.layers.reshape(
                    self.pred_fake,
                    [-1, shape_4d[1] * shape_4d[2] * shape_4d[3]],
                    inplace=True)
                fake_label = fluid.layers.fill_constant(
                    shape=[batch] + list(self.pred_fake.shape[1:]),
                    value=0,
                    dtype='float32')
                real_label = fluid.layers.fill_constant(
                    shape=[batch] + list(self.pred_real.shape[1:]),
                    value=1,
                    dtype='float32')
                real_bce = fluid.layers.sigmoid_cross_entropy_with_logits(
                    x=self.pred_real, label=real_label)
                self.d_loss_real = fluid.layers.mean(real_bce)
                fake_bce = fluid.layers.sigmoid_cross_entropy_with_logits(
                    x=self.pred_fake, label=fake_label)
                self.d_loss_fake = fluid.layers.mean(fake_bce)
            else:
                raise NotImplementedError("gan_mode {} is not support!".format(
                    cfg.gan_mode))

            self.d_loss = 0.5 * (self.d_loss_real + self.d_loss_fake)

            # Only the discriminator's own parameters go to the optimizer.
            discriminator_params = [
                candidate.name for candidate in self.program.list_vars()
                if fluid.io.is_parameter(candidate) and
                candidate.name.startswith("discriminator")
            ]
            self.param = discriminator_params

            if cfg.epoch > 100:
                # Constant rate up to step 99 * step_per_epoch, then one
                # smaller value per following epoch.
                decay_boundaries = [99 * step_per_epoch] + [
                    epoch_idx * step_per_epoch
                    for epoch_idx in range(100, cfg.epoch - 1)
                ]
                decay_values = [base_lr] + [
                    base_lr * (1.0 - (epoch_idx - 99.0) / 101.0)
                    for epoch_idx in range(100, cfg.epoch)
                ]
                learning_rate = fluid.layers.piecewise_decay(
                    boundaries=decay_boundaries, values=decay_values)
            else:
                learning_rate = base_lr

            optimizer = fluid.optimizer.Adam(
                learning_rate=learning_rate,
                beta1=0.5,
                beta2=0.999,
                name="net_D")

            optimizer.minimize(self.d_loss, parameter_list=discriminator_params)


class Pix2pix(object):
    """Training driver for Pix2pix: CLI options, graph setup and train loop."""

    def add_special_args(self, parser):
        """Register the Pix2pix-specific command line options.

        Args:
            parser: argparse-style parser exposing ``add_argument``.

        Returns:
            The same ``parser`` object, with the extra options registered.
        """
        parser.add_argument(
            '--net_G',
            type=str,
            default="unet_256",
            help="Choose the Pix2pix generator's network, choose in [resnet_9block|resnet_6block|unet_128|unet_256]"
        )
        parser.add_argument(
            '--net_D',
            type=str,
            default="basic",
            help="Choose the Pix2pix discriminator's network, choose in [basic|nlayers|pixel]"
        )
        parser.add_argument(
            '--d_nlayers',
            type=int,
            default=3,
            help="only used when Pix2pix discriminator is nlayers")
        parser.add_argument(
            '--enable_ce',
            action='store_true',
            help="if set, run the tasks with continuous evaluation logs")
        return parser

    def __init__(self,
                 cfg=None,
                 train_reader=None,
                 test_reader=None,
                 batch_num=1,
                 id2name=None):
        """Store the configuration and data sources used by ``build_model``.

        Args:
            cfg: parsed configuration object.
            train_reader: batch generator handed to the training loader.
            test_reader: batch generator handed to the test loader.
            batch_num: passed to the trainers as their ``step_per_epoch``.
            id2name: forwarded to ``utility.save_test_image`` as
                ``A_id2name``.
        """
        self.cfg = cfg
        self.train_reader = train_reader
        self.test_reader = test_reader
        self.batch_num = batch_num
        self.id2name = id2name

    def build_model(self):
        """Build the generator/discriminator programs and run training.

        Each iteration runs the generator program, feeds the generated
        images into the discriminator program as ``input_fake``, and logs
        the losses every ``cfg.print_freq`` batches. After each epoch it
        optionally saves test images and checkpoints.

        Returns early (``None``) when ``cfg.max_iter`` batches have been
        trained, or a few batches after the profiler reset when
        ``cfg.profile`` is set.
        """
        data_shape = [None, 3, self.cfg.crop_size, self.cfg.crop_size]

        input_A = fluid.data(name='input_A', shape=data_shape, dtype='float32')
        input_B = fluid.data(name='input_B', shape=data_shape, dtype='float32')
        input_fake = fluid.data(
            name='input_fake', shape=data_shape, dtype='float32')
        # used for continuous evaluation
        if self.cfg.enable_ce:
            fluid.default_startup_program().random_seed = 90

        loader = fluid.io.DataLoader.from_generator(
            feed_list=[input_A, input_B],
            capacity=4,
            iterable=True,
            use_double_buffer=True)

        gen_trainer = GTrainer(input_A, input_B, self.cfg, self.batch_num)
        dis_trainer = DTrainer(input_A, input_B, input_fake, self.cfg,
                               self.batch_num)

        # prepare environment
        place = fluid.CUDAPlace(0) if self.cfg.use_gpu else fluid.CPUPlace()
        loader.set_batch_generator(
            self.train_reader,
            places=fluid.cuda_places()
            if self.cfg.use_gpu else fluid.cpu_places())
        exe = fluid.Executor(place)
        exe.run(fluid.default_startup_program())

        if self.cfg.init_model:
            utility.init_checkpoints(self.cfg, gen_trainer, "net_G")
            utility.init_checkpoints(self.cfg, dis_trainer, "net_D")

        ### memory optim
        build_strategy = fluid.BuildStrategy()

        gen_trainer_program = fluid.CompiledProgram(
            gen_trainer.program).with_data_parallel(
                loss_name=gen_trainer.g_loss.name,
                build_strategy=build_strategy)
        dis_trainer_program = fluid.CompiledProgram(
            dis_trainer.program).with_data_parallel(
                loss_name=dis_trainer.d_loss.name,
                build_strategy=build_strategy)

        total_train_batch = 0  # used for benchmark
        # Wall time of the most recent batch. The continuous-evaluation KPI
        # printed after training reads this name, so it must always be bound.
        batch_time = 0.0
        reader_cost_averager = timer.TimeAverager()
        batch_cost_averager = timer.TimeAverager()
        for epoch_id in range(self.cfg.epoch):
            batch_id = 0
            batch_start = time.time()
            for tensor in loader():
                if self.cfg.max_iter and total_train_batch == self.cfg.max_iter:  # used for benchmark
                    return
                reader_cost_averager.record(time.time() - batch_start)

                # optimize the generator network
                g_loss_gan, g_loss_l1, fake_B_tmp = exe.run(
                    gen_trainer_program,
                    fetch_list=[
                        gen_trainer.g_loss_gan, gen_trainer.g_loss_L1,
                        gen_trainer.fake_B
                    ],
                    feed=tensor)

                # Split the fetched fake images evenly and attach one slice
                # to each per-device feed entry for the discriminator step.
                devices_num = utility.get_device_num(self.cfg)
                fake_per_device = len(fake_B_tmp) // devices_num
                for dev in range(devices_num):
                    tensor[dev]['input_fake'] = fake_B_tmp[
                        dev * fake_per_device:(dev + 1) * fake_per_device]

                # optimize the discriminator network
                d_loss_real, d_loss_fake = exe.run(dis_trainer_program,
                                                   fetch_list=[
                                                       dis_trainer.d_loss_real,
                                                       dis_trainer.d_loss_fake
                                                   ],
                                                   feed=tensor)

                batch_time = time.time() - batch_start
                batch_cost_averager.record(batch_time)
                if batch_id % self.cfg.print_freq == 0:
                    print("epoch{}: batch{}: \n\
                         g_loss_gan: {}; g_loss_l1: {}; \n\
                         d_loss_real: {}; d_loss_fake: {}; \n\
                         reader_cost: {}, Batch_time_cost: {}"
                          .format(epoch_id, batch_id, g_loss_gan[0],
                                  g_loss_l1[0], d_loss_real[0], d_loss_fake[0],
                                  reader_cost_averager.get_average(),
                                  batch_cost_averager.get_average()))
                    reader_cost_averager.reset()
                    batch_cost_averager.reset()

                sys.stdout.flush()
                batch_id += 1
                total_train_batch += 1  # used for benchmark
                batch_start = time.time()

                # profiler tools
                # batch_id was already incremented above, so these compare
                # against the number of batches finished in the first epoch.
                if self.cfg.profile and epoch_id == 0 and batch_id == self.cfg.print_freq:
                    profiler.reset_profiler()
                elif self.cfg.profile and epoch_id == 0 and batch_id == self.cfg.print_freq + 5:
                    return

            if self.cfg.run_test:
                image_name = fluid.data(
                    name='image_name',
                    shape=[None, self.cfg.batch_size],
                    dtype="int32")
                test_loader = fluid.io.DataLoader.from_generator(
                    feed_list=[input_A, input_B, image_name],
                    capacity=4,
                    iterable=True,
                    use_double_buffer=True)
                test_loader.set_batch_generator(
                    self.test_reader,
                    places=fluid.cuda_places()
                    if self.cfg.use_gpu else fluid.cpu_places())
                test_program = gen_trainer.infer_program
                utility.save_test_image(
                    epoch_id,
                    self.cfg,
                    exe,
                    place,
                    test_program,
                    gen_trainer,
                    test_loader,
                    A_id2name=self.id2name)

            if self.cfg.save_checkpoints:
                utility.checkpoints(epoch_id, self.cfg, gen_trainer, "net_G")
                utility.checkpoints(epoch_id, self.cfg, dis_trainer, "net_D")
        if self.cfg.enable_ce:
            # KPI lines report the values from the last trained batch.
            device_num = fluid.core.get_cuda_device_count(
            ) if self.cfg.use_gpu else 1
            print("kpis\tpix2pix_g_loss_gan_card{}\t{}".format(device_num,
                                                               g_loss_gan[0]))
            print("kpis\tpix2pix_g_loss_l1_card{}\t{}".format(device_num,
                                                              g_loss_l1[0]))
            print("kpis\tpix2pix_d_loss_real_card{}\t{}".format(device_num,
                                                                d_loss_real[0]))
            print("kpis\tpix2pix_d_loss_fake_card{}\t{}".format(device_num,
                                                                d_loss_fake[0]))
            print("kpis\tpix2pix_Batch_time_cost_card{}\t{}".format(device_num,
                                                                    batch_time))