#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from network.CycleGAN_network import CycleGAN_model
from util import utility
from util import timer
import paddle.fluid as fluid
from paddle.fluid import profiler
import paddle
import sys
import time

lambda_A = 10.0
lambda_B = 10.0
lambda_identity = 0.5


class GTrainer():
    """Builds the generator-side training program.

    Wires up both translations (A->B via GA, B->A via GB), their round
    trips, and the combined objective: cycle-consistency (L1), least-squares
    adversarial loss against D_A/D_B, and identity-mapping loss.  Only the
    GA*/GB* parameters are optimized here, with Adam (linear LR decay after
    epoch 100 when cfg.epoch > 100).
    """

    def __init__(self, input_A, input_B, cfg, step_per_epoch):
        self.program = fluid.default_main_program().clone()
        with fluid.program_guard(self.program):
            net = CycleGAN_model()
            # Forward translations and their round trips.
            self.fake_B = net.network_G(input_A, name="GA", cfg=cfg)
            self.fake_A = net.network_G(input_B, name="GB", cfg=cfg)
            self.cyc_A = net.network_G(self.fake_B, name="GB", cfg=cfg)
            self.cyc_B = net.network_G(self.fake_A, name="GA", cfg=cfg)

            # Snapshot for inference before any loss ops are appended.
            self.infer_program = self.program.clone()

            # Cycle-consistency: L1 distance between an input and its
            # round-trip reconstruction, scaled per direction.
            abs_diff_A = fluid.layers.abs(
                fluid.layers.elementwise_sub(
                    x=input_A, y=self.cyc_A))
            abs_diff_B = fluid.layers.abs(
                fluid.layers.elementwise_sub(
                    x=input_B, y=self.cyc_B))
            self.cyc_A_loss = fluid.layers.reduce_mean(abs_diff_A) * lambda_A
            self.cyc_B_loss = fluid.layers.reduce_mean(abs_diff_B) * lambda_B
            self.cyc_loss = self.cyc_A_loss + self.cyc_B_loss

            # LSGAN generator losses: push D's score on fakes toward 1.
            self.fake_rec_A = net.network_D(self.fake_B, name="DA", cfg=cfg)
            self.G_A = fluid.layers.reduce_mean(
                fluid.layers.square(self.fake_rec_A - 1))
            self.fake_rec_B = net.network_D(self.fake_A, name="DB", cfg=cfg)
            self.G_B = fluid.layers.reduce_mean(
                fluid.layers.square(self.fake_rec_B - 1))
            self.G = self.G_A + self.G_B

            # Identity loss: a generator fed an image already in its target
            # domain should leave it (nearly) unchanged.
            self.idt_A = net.network_G(input_B, name="GA", cfg=cfg)
            self.idt_loss_A = fluid.layers.reduce_mean(
                fluid.layers.abs(
                    fluid.layers.elementwise_sub(
                        x=input_B, y=self.idt_A))) * lambda_B * lambda_identity
            self.idt_B = net.network_G(input_A, name="GB", cfg=cfg)
            self.idt_loss_B = fluid.layers.reduce_mean(
                fluid.layers.abs(
                    fluid.layers.elementwise_sub(
                        x=input_A, y=self.idt_B))) * lambda_A * lambda_identity
            self.idt_loss = fluid.layers.elementwise_add(self.idt_loss_A,
                                                         self.idt_loss_B)

            self.g_loss = self.cyc_loss + self.G + self.idt_loss

            # Restrict the update to generator parameters only (GA*/GB*).
            param_names = [
                var.name
                for var in self.program.list_vars()
                if fluid.io.is_parameter(var) and (
                    var.name.startswith("GA") or var.name.startswith("GB"))
            ]
            self.param = param_names

            lr = cfg.learning_rate
            if cfg.epoch <= 100:
                optimizer = fluid.optimizer.Adam(
                    learning_rate=lr, beta1=0.5, beta2=0.999, name="net_G")
            else:
                # Flat LR for the first 100 epochs, then linear decay.
                optimizer = fluid.optimizer.Adam(
                    learning_rate=fluid.layers.piecewise_decay(
                        boundaries=[99 * step_per_epoch] + [
                            x * step_per_epoch
                            for x in range(100, cfg.epoch - 1)
                        ],
                        values=[lr] + [
                            lr * (1.0 - (x - 99.0) / 101.0)
                            for x in range(100, cfg.epoch)
                        ]),
                    beta1=0.5,
                    beta2=0.999,
                    name="net_G")
            optimizer.minimize(self.g_loss, parameter_list=param_names)


class DATrainer():
    """Builds the training program for discriminator D_A.

    D_A judges domain-B images: its LSGAN loss drives real-B scores toward
    1 and pooled-fake-B scores toward 0, halved so D learns at half the
    pace of G.  Only DA* parameters are optimized, with Adam (linear LR
    decay after epoch 100 when cfg.epoch > 100).
    """

    def __init__(self, input_B, fake_pool_B, cfg, step_per_epoch):
        self.program = fluid.default_main_program().clone()
        with fluid.program_guard(self.program):
            net = CycleGAN_model()
            self.rec_B = net.network_D(input_B, name="DA", cfg=cfg)
            self.fake_pool_rec_B = net.network_D(
                fake_pool_B, name="DA", cfg=cfg)
            # LSGAN discriminator objective, averaged over the batch.
            self.d_loss_A = (fluid.layers.square(self.fake_pool_rec_B) +
                             fluid.layers.square(self.rec_B - 1)) / 2.0
            self.d_loss_A = fluid.layers.reduce_mean(self.d_loss_A)

            # Restrict the update to D_A's own parameters.
            param_names = [
                var.name
                for var in self.program.list_vars()
                if fluid.io.is_parameter(var) and var.name.startswith("DA")
            ]
            self.param = param_names

            lr = cfg.learning_rate
            if cfg.epoch <= 100:
                optimizer = fluid.optimizer.Adam(
                    learning_rate=lr, beta1=0.5, beta2=0.999, name="net_DA")
            else:
                # Flat LR for the first 100 epochs, then linear decay.
                optimizer = fluid.optimizer.Adam(
                    learning_rate=fluid.layers.piecewise_decay(
                        boundaries=[99 * step_per_epoch] + [
                            x * step_per_epoch
                            for x in range(100, cfg.epoch - 1)
                        ],
                        values=[lr] + [
                            lr * (1.0 - (x - 99.0) / 101.0)
                            for x in range(100, cfg.epoch)
                        ]),
                    beta1=0.5,
                    beta2=0.999,
                    name="net_DA")
            optimizer.minimize(self.d_loss_A, parameter_list=param_names)


class DBTrainer():
    """Builds the training program for discriminator D_B.

    D_B judges domain-A images: its LSGAN loss drives real-A scores toward
    1 and pooled-fake-A scores toward 0, halved so D learns at half the
    pace of G.  Only DB* parameters are optimized, with Adam (linear LR
    decay after epoch 100 when cfg.epoch > 100).

    Fixes vs. the previous revision:
      * learning rate now honors ``cfg.learning_rate`` instead of a
        hard-coded 0.0002 (GTrainer/DATrainer already did);
      * the constant-LR Adam is named "net_DB", not "net_DA", so its
        optimizer state no longer collides with DATrainer's and
        checkpoints save/load under the correct prefix (matching
        ``utility.init_checkpoints(..., "net_DB")`` and the decay branch).
    """

    def __init__(self, input_A, fake_pool_A, cfg, step_per_epoch):
        self.program = fluid.default_main_program().clone()
        with fluid.program_guard(self.program):
            model = CycleGAN_model()
            self.rec_A = model.network_D(input_A, name="DB", cfg=cfg)
            self.fake_pool_rec_A = model.network_D(
                fake_pool_A, name="DB", cfg=cfg)
            # LSGAN discriminator objective, averaged over the batch.
            self.d_loss_B = (fluid.layers.square(self.fake_pool_rec_A) +
                             fluid.layers.square(self.rec_A - 1)) / 2.0
            self.d_loss_B = fluid.layers.reduce_mean(self.d_loss_B)

            # Restrict the update to D_B's own parameters.
            param_names = [
                var.name
                for var in self.program.list_vars()
                if fluid.io.is_parameter(var) and var.name.startswith("DB")
            ]
            self.param = param_names

            # BUG FIX: was hard-coded to 0.0002, silently ignoring the
            # configured learning rate used by the other two trainers.
            lr = cfg.learning_rate
            if cfg.epoch <= 100:
                # BUG FIX: this optimizer was misnamed "net_DA".
                optimizer = fluid.optimizer.Adam(
                    learning_rate=lr, beta1=0.5, beta2=0.999, name="net_DB")
            else:
                # Flat LR for the first 100 epochs, then linear decay.
                optimizer = fluid.optimizer.Adam(
                    learning_rate=fluid.layers.piecewise_decay(
                        boundaries=[99 * step_per_epoch] + [
                            x * step_per_epoch
                            for x in range(100, cfg.epoch - 1)
                        ],
                        values=[lr] + [
                            lr * (1.0 - (x - 99.0) / 101.0)
                            for x in range(100, cfg.epoch)
                        ]),
                    beta1=0.5,
                    beta2=0.999,
                    name="net_DB")
            optimizer.minimize(self.d_loss_B, parameter_list=param_names)


class CycleGAN(object):
    """CycleGAN training driver (PaddlePaddle 1.x static graph).

    Builds the generator and two discriminator programs, then runs the
    alternating optimization loop: one G step, one D_A step, one D_B step
    per batch, feeding the discriminators from image history pools.

    Fix vs. the previous revision: the ``enable_ce`` KPI summary at the
    end referenced an undefined name ``batch_time`` (NameError); the
    per-batch wall time is now kept in that variable.
    """

    def add_special_args(self, parser):
        """Register CycleGAN-specific CLI options on *parser*; return it."""
        parser.add_argument(
            '--net_G',
            type=str,
            default="resnet_9block",
            help="Choose the CycleGAN generator's network, choose in [resnet_9block|resnet_6block|unet_128|unet_256]"
        )
        parser.add_argument(
            '--net_D',
            type=str,
            default="basic",
            help="Choose the CycleGAN discriminator's network, choose in [basic|nlayers|pixel]"
        )
        parser.add_argument(
            '--d_nlayers',
            type=int,
            default=3,
            help="only used when CycleGAN discriminator is nlayers")
        parser.add_argument(
            '--enable_ce',
            action='store_true',
            help="if set, run the tasks with continuous evaluation logs")
        return parser

    def __init__(self,
                 cfg=None,
                 A_reader=None,
                 B_reader=None,
                 A_test_reader=None,
                 B_test_reader=None,
                 batch_num=1,
                 A_id2name=None,
                 B_id2name=None):
        # cfg: parsed argparse namespace with training options.
        # A_reader/B_reader: batch generators for the two image domains;
        # *_test_reader: generators for the periodic test pass.
        # batch_num: steps per epoch (drives the LR decay schedule).
        # A_id2name/B_id2name: id -> filename maps used when saving
        # test images.
        self.cfg = cfg
        self.A_reader = A_reader
        self.B_reader = B_reader
        self.A_test_reader = A_test_reader
        self.B_test_reader = B_test_reader
        self.batch_num = batch_num
        self.A_id2name = A_id2name
        self.B_id2name = B_id2name

    def build_model(self):
        """Construct all programs and run the full training loop."""
        data_shape = [None, 3, self.cfg.crop_size, self.cfg.crop_size]

        input_A = fluid.data(name='input_A', shape=data_shape, dtype='float32')
        input_B = fluid.data(name='input_B', shape=data_shape, dtype='float32')
        fake_pool_A = fluid.data(
            name='fake_pool_A', shape=data_shape, dtype='float32')
        fake_pool_B = fluid.data(
            name='fake_pool_B', shape=data_shape, dtype='float32')
        # used for continuous evaluation: fixed seed for reproducibility
        if self.cfg.enable_ce:
            fluid.default_startup_program().random_seed = 90

        A_loader = fluid.io.DataLoader.from_generator(
            feed_list=[input_A],
            capacity=4,
            iterable=True,
            use_double_buffer=True)
        B_loader = fluid.io.DataLoader.from_generator(
            feed_list=[input_B],
            capacity=4,
            iterable=True,
            use_double_buffer=True)

        gen_trainer = GTrainer(input_A, input_B, self.cfg, self.batch_num)
        d_A_trainer = DATrainer(input_B, fake_pool_B, self.cfg, self.batch_num)
        d_B_trainer = DBTrainer(input_A, fake_pool_A, self.cfg, self.batch_num)

        # prepare environment
        place = fluid.CUDAPlace(0) if self.cfg.use_gpu else fluid.CPUPlace()
        A_loader.set_batch_generator(
            self.A_reader,
            places=fluid.cuda_places()
            if self.cfg.use_gpu else fluid.cpu_places())
        B_loader.set_batch_generator(
            self.B_reader,
            places=fluid.cuda_places()
            if self.cfg.use_gpu else fluid.cpu_places())

        exe = fluid.Executor(place)
        exe.run(fluid.default_startup_program())

        # History pools: discriminators see a buffer of past fakes, which
        # stabilizes adversarial training.
        A_pool = utility.ImagePool()
        B_pool = utility.ImagePool()

        if self.cfg.init_model:
            utility.init_checkpoints(self.cfg, gen_trainer, "net_G")
            utility.init_checkpoints(self.cfg, d_A_trainer, "net_DA")
            utility.init_checkpoints(self.cfg, d_B_trainer, "net_DB")

        ### memory optim
        build_strategy = fluid.BuildStrategy()
        build_strategy.enable_inplace = True

        gen_trainer_program = fluid.CompiledProgram(
            gen_trainer.program).with_data_parallel(
                loss_name=gen_trainer.g_loss.name,
                build_strategy=build_strategy)
        d_A_trainer_program = fluid.CompiledProgram(
            d_A_trainer.program).with_data_parallel(
                loss_name=d_A_trainer.d_loss_A.name,
                build_strategy=build_strategy)
        d_B_trainer_program = fluid.CompiledProgram(
            d_B_trainer.program).with_data_parallel(
                loss_name=d_B_trainer.d_loss_B.name,
                build_strategy=build_strategy)

        total_train_batch = 0  # NOTE :used for benchmark
        reader_cost_averager = timer.TimeAverager()
        batch_cost_averager = timer.TimeAverager()
        # BUG FIX: wall time of the most recent batch; previously the
        # enable_ce summary below referenced this name without it ever
        # being assigned (NameError).
        batch_time = 0.0
        for epoch_id in range(self.cfg.epoch):
            batch_id = 0
            batch_start = time.time()
            for data_A, data_B in zip(A_loader(), B_loader()):
                if self.cfg.max_iter and total_train_batch == self.cfg.max_iter:  # used for benchmark
                    return
                reader_cost_averager.record(time.time() - batch_start)

                tensor_A, tensor_B = data_A[0]['input_A'], data_B[0]['input_B']
                ## optimize the g_A network
                g_A_loss, g_A_cyc_loss, g_A_idt_loss, g_B_loss, g_B_cyc_loss,\
                g_B_idt_loss, fake_A_tmp, fake_B_tmp = exe.run(
                    gen_trainer_program,
                    fetch_list=[
                        gen_trainer.G_A, gen_trainer.cyc_A_loss,
                        gen_trainer.idt_loss_A, gen_trainer.G_B,
                        gen_trainer.cyc_B_loss, gen_trainer.idt_loss_B,
                        gen_trainer.fake_A, gen_trainer.fake_B
                    ],
                    feed={"input_A": tensor_A,
                          "input_B": tensor_B})

                fake_pool_B = B_pool.pool_image(fake_B_tmp)
                fake_pool_A = A_pool.pool_image(fake_A_tmp)

                # CE runs must be deterministic, so bypass the random pools.
                if self.cfg.enable_ce:
                    fake_pool_B = fake_B_tmp
                    fake_pool_A = fake_A_tmp

                # optimize the d_A network
                d_A_loss = exe.run(
                    d_A_trainer_program,
                    fetch_list=[d_A_trainer.d_loss_A],
                    feed={"input_B": tensor_B,
                          "fake_pool_B": fake_pool_B})[0]

                # optimize the d_B network
                d_B_loss = exe.run(
                    d_B_trainer_program,
                    fetch_list=[d_B_trainer.d_loss_B],
                    feed={"input_A": tensor_A,
                          "fake_pool_A": fake_pool_A})[0]

                batch_time = time.time() - batch_start
                batch_cost_averager.record(
                    batch_time, num_samples=self.cfg.batch_size)
                if batch_id % self.cfg.print_freq == 0:
                    print("epoch{}: batch{}: \n\
                         d_A_loss: {:.5f}; g_A_loss: {:.5f}; g_A_cyc_loss: {:.5f}; g_A_idt_loss: {:.5f}; \n\
                         d_B_loss: {:.5f}; g_B_loss: {:.5f}; g_B_cyc_loss: {:.5f}; g_B_idt_loss: {:.5f}; \n\
                         batch_cost: {:.5f} sec, reader_cost: {:.5f} sec, ips: {:.5f} images/sec"
                          .format(epoch_id, batch_id, d_A_loss[0], g_A_loss[
                              0], g_A_cyc_loss[0], g_A_idt_loss[0], d_B_loss[0],
                                  g_B_loss[0], g_B_cyc_loss[0], g_B_idt_loss[0],
                                  batch_cost_averager.get_average(),
                                  reader_cost_averager.get_average(),
                                  batch_cost_averager.get_ips_average()))
                    reader_cost_averager.reset()
                    batch_cost_averager.reset()

                sys.stdout.flush()
                batch_id += 1
                total_train_batch += 1  # used for benchmark
                batch_start = time.time()

                # profiler tools: profile a short window early in epoch 0
                if self.cfg.profile and epoch_id == 0 and batch_id == self.cfg.print_freq:
                    profiler.reset_profiler()
                elif self.cfg.profile and epoch_id == 0 and batch_id == self.cfg.print_freq + 5:
                    return

                # used for continuous evaluation
                if self.cfg.enable_ce and batch_id == 10:
                    break

            if self.cfg.run_test:
                # NOTE(review): these test loaders are rebuilt on every
                # epoch that runs a test pass — presumably tolerated by
                # this fluid version; confirm before refactoring.
                A_image_name = fluid.data(
                    name='A_image_name', shape=[None, 1], dtype='int32')
                B_image_name = fluid.data(
                    name='B_image_name', shape=[None, 1], dtype='int32')
                A_test_loader = fluid.io.DataLoader.from_generator(
                    feed_list=[input_A, A_image_name],
                    capacity=4,
                    iterable=True,
                    use_double_buffer=True)
                B_test_loader = fluid.io.DataLoader.from_generator(
                    feed_list=[input_B, B_image_name],
                    capacity=4,
                    iterable=True,
                    use_double_buffer=True)
                A_test_loader.set_batch_generator(
                    self.A_test_reader,
                    places=fluid.cuda_places()
                    if self.cfg.use_gpu else fluid.cpu_places())
                B_test_loader.set_batch_generator(
                    self.B_test_reader,
                    places=fluid.cuda_places()
                    if self.cfg.use_gpu else fluid.cpu_places())
                test_program = gen_trainer.infer_program
                utility.save_test_image(
                    epoch_id,
                    self.cfg,
                    exe,
                    place,
                    test_program,
                    gen_trainer,
                    A_test_loader,
                    B_test_loader,
                    A_id2name=self.A_id2name,
                    B_id2name=self.B_id2name)

            if self.cfg.save_checkpoints:
                utility.checkpoints(epoch_id, self.cfg, gen_trainer, "net_G")
                utility.checkpoints(epoch_id, self.cfg, d_A_trainer, "net_DA")
                utility.checkpoints(epoch_id, self.cfg, d_B_trainer, "net_DB")

        # used for continuous evaluation: emit final-batch KPIs
        if self.cfg.enable_ce:
            device_num = fluid.core.get_cuda_device_count(
            ) if self.cfg.use_gpu else 1
            print("kpis\tcyclegan_g_A_loss_card{}\t{}".format(device_num,
                                                              g_A_loss[0]))
            print("kpis\tcyclegan_g_A_cyc_loss_card{}\t{}".format(
                device_num, g_A_cyc_loss[0]))
            print("kpis\tcyclegan_g_A_idt_loss_card{}\t{}".format(
                device_num, g_A_idt_loss[0]))
            print("kpis\tcyclegan_d_A_loss_card{}\t{}".format(device_num,
                                                              d_A_loss[0]))
            print("kpis\tcyclegan_g_B_loss_card{}\t{}".format(device_num,
                                                              g_B_loss[0]))
            print("kpis\tcyclegan_g_B_cyc_loss_card{}\t{}".format(
                device_num, g_B_cyc_loss[0]))
            print("kpis\tcyclegan_g_B_idt_loss_card{}\t{}".format(
                device_num, g_B_idt_loss[0]))
            print("kpis\tcyclegan_d_B_loss_card{}\t{}".format(device_num,
                                                              d_B_loss[0]))
            print("kpis\tcyclegan_Batch_time_cost_card{}\t{}".format(
                device_num, batch_time))