#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from network.STGAN_network import STGAN_model
from util import utility
import paddle.fluid as fluid
from paddle.fluid import profiler
import sys
import time
import copy
import numpy as np
import ast


class GTrainer():
    """Build the STGAN generator training program.

    The graph is built inside a clone of the default main program, so this
    trainer owns an independent ``self.program``.

    Args:
        image_real: input image variable.
        label_org: original attribute labels. Not used by the generator
            losses; accepted so the signature matches ``DTrainer``.
        label_org_: original attribute labels as fed to the generator.
        label_trg: target attribute labels, used by the classification loss.
        label_trg_: target attribute labels as fed to the generator.
        cfg: config namespace. This class reads ``gan_mode``, ``lambda_rec``,
            ``lambda_cls`` and ``g_lr``; the networks may read more.
        step_per_epoch: the learning rate drops to ``0.1 * cfg.g_lr`` after
            ``99 * step_per_epoch`` steps.

    Raises:
        NotImplementedError: if ``cfg.gan_mode`` is neither "wgan" nor "lsgan".
    """

    def __init__(self, image_real, label_org, label_org_, label_trg, label_trg_,
                 cfg, step_per_epoch):
        self.program = fluid.default_main_program().clone()
        with fluid.program_guard(self.program):
            model = STGAN_model()
            self.fake_img, self.rec_img = model.network_G(
                image_real, label_org_, label_trg_, cfg, name="generator")
            self.fake_img.persistable = True
            self.rec_img.persistable = True
            # Cloned before any loss/optimizer ops are appended, so the
            # inference program holds only the generator forward pass.
            self.infer_program = self.program.clone(for_test=True)
            # L1 reconstruction loss between the input and its reconstruction.
            self.g_loss_rec = fluid.layers.mean(
                fluid.layers.abs(
                    fluid.layers.elementwise_sub(
                        x=image_real, y=self.rec_img)))
            self.pred_fake, self.cls_fake = model.network_D(
                self.fake_img, cfg, name="discriminator")
            if cfg.gan_mode == "wgan":
                # WGAN: minimizing -mean(D(fake)) raises the critic score.
                self.g_loss_fake = -1 * fluid.layers.mean(self.pred_fake)
            elif cfg.gan_mode == "lsgan":
                # LSGAN: squared distance of D(fake) from 1.
                ones = fluid.layers.fill_constant_batch_size_like(
                    input=self.pred_fake,
                    shape=self.pred_fake.shape,
                    value=1.0,
                    dtype='float32')
                self.g_loss_fake = fluid.layers.mean(
                    fluid.layers.square(
                        fluid.layers.elementwise_sub(
                            x=self.pred_fake, y=ones)))
            else:
                raise NotImplementedError("gan_mode {} is not support!".format(
                    cfg.gan_mode))

            # Attribute classification loss of the fakes against the targets.
            self.g_loss_cls = fluid.layers.mean(
                fluid.layers.sigmoid_cross_entropy_with_logits(self.cls_fake,
                                                               label_trg))
            self.g_loss = (self.g_loss_fake + cfg.lambda_rec * self.g_loss_rec +
                           cfg.lambda_cls * self.g_loss_cls)
            self.g_loss_fake.persistable = True
            self.g_loss_rec.persistable = True
            self.g_loss_cls.persistable = True
            lr = cfg.g_lr
            # Only parameters under the "generator" prefix are optimized here.
            self.param = [
                var.name for var in self.program.list_vars()
                if fluid.io.is_parameter(var) and
                var.name.startswith("generator")
            ]
            optimizer = fluid.optimizer.Adam(
                learning_rate=fluid.layers.piecewise_decay(
                    boundaries=[99 * step_per_epoch], values=[lr, lr * 0.1]),
                beta1=0.5,
                beta2=0.999,
                name="net_G")

            optimizer.minimize(self.g_loss, parameter_list=self.param)


class DTrainer():
    """Build the STGAN discriminator training program.

    The graph is built inside a clone of the default main program, so this
    trainer owns an independent ``self.program``.

    Args:
        image_real: input image variable.
        label_org: original attribute labels, used by the classification loss
            on real images.
        label_org_: original attribute labels as fed to the generator.
        label_trg: target attribute labels (not used by this class; accepted
            so the signature matches ``GTrainer``).
        label_trg_: target attribute labels as fed to the generator.
        cfg: config namespace. This class reads ``d_lr``, ``gan_mode``,
            ``lambda_gp`` and ``enable_ce``; the networks may read more.
        step_per_epoch: the learning rate drops to ``0.1 * cfg.d_lr`` after
            ``99 * step_per_epoch`` steps.

    Raises:
        NotImplementedError: if ``cfg.gan_mode`` is neither "wgan" nor "lsgan".
    """

    def __init__(self, image_real, label_org, label_org_, label_trg, label_trg_,
                 cfg, step_per_epoch):
        self.program = fluid.default_main_program().clone()
        lr = cfg.d_lr
        with fluid.program_guard(self.program):
            model = STGAN_model()
            self.fake_img, _ = model.network_G(
                image_real, label_org_, label_trg_, cfg, name="generator")
            self.pred_real, self.cls_real = model.network_D(
                image_real, cfg, name="discriminator")
            self.pred_real.persistable = True
            self.cls_real.persistable = True
            self.pred_fake, _ = model.network_D(
                self.fake_img, cfg, name="discriminator")
            self.d_loss_cls = fluid.layers.mean(
                fluid.layers.sigmoid_cross_entropy_with_logits(self.cls_real,
                                                               label_org))
            if cfg.gan_mode == "wgan":
                self.d_loss_fake = fluid.layers.reduce_mean(self.pred_fake)
                self.d_loss_real = -1 * fluid.layers.reduce_mean(self.pred_real)
                # Penalty on interpolates between real and fake images.
                self.d_loss_gp = self.gradient_penalty(
                    model.network_D,
                    image_real,
                    self.fake_img,
                    cfg=cfg,
                    name="discriminator")
            elif cfg.gan_mode == "lsgan":
                # LSGAN: push D(real) towards 1 and D(fake) towards 0.
                ones = fluid.layers.fill_constant_batch_size_like(
                    input=self.pred_real,
                    shape=self.pred_real.shape,
                    value=1.0,
                    dtype='float32')
                self.d_loss_real = fluid.layers.mean(
                    fluid.layers.square(
                        fluid.layers.elementwise_sub(
                            x=self.pred_real, y=ones)))
                self.d_loss_fake = fluid.layers.mean(
                    fluid.layers.square(x=self.pred_fake))
                # No fake endpoint: the penalty uses perturbed real images.
                self.d_loss_gp = self.gradient_penalty(
                    model.network_D,
                    image_real,
                    None,
                    cfg=cfg,
                    name="discriminator")
            else:
                raise NotImplementedError("gan_mode {} is not support!".format(
                    cfg.gan_mode))

            # The total discriminator loss has the same form in both modes.
            self.d_loss = (self.d_loss_real + self.d_loss_fake +
                           1.0 * self.d_loss_cls + cfg.lambda_gp * self.d_loss_gp)

            self.d_loss_real.persistable = True
            self.d_loss_fake.persistable = True
            self.d_loss.persistable = True
            self.d_loss_cls.persistable = True
            self.d_loss_gp.persistable = True
            # Only parameters under the "discriminator" prefix are optimized.
            self.param = [
                var.name for var in self.program.list_vars()
                if fluid.io.is_parameter(var) and
                var.name.startswith("discriminator")
            ]

            optimizer = fluid.optimizer.Adam(
                learning_rate=fluid.layers.piecewise_decay(
                    boundaries=[99 * step_per_epoch],
                    values=[lr, lr * 0.1], ),
                beta1=0.5,
                beta2=0.999,
                name="net_D")

            optimizer.minimize(self.d_loss, parameter_list=self.param)
            # Debug dump of the program text. The file is now closed
            # deterministically. NOTE: despite its name, the file holds the
            # discriminator program; the name is kept for compatibility.
            with open('G_program.txt', 'w') as f:
                print(self.program, file=f)

    def gradient_penalty(self, f, real, fake=None, cfg=None, name=None):
        """Return the gradient penalty ``mean((||d f(x) / dx|| - 1) ** 2)``.

        ``x = a + alpha * (b - a)`` with ``a = real`` and one ``alpha`` per
        sample, drawn uniformly from [0, 1].
        If ``fake`` is given, ``b = fake``. Otherwise ``b`` is built from
        ``real`` as ``real + 0.5 * beta * std(real)``. Here ``beta`` is
        elementwise uniform on [0, 1], and the std is taken over all axes of
        the batch. This resembles the DRAGAN perturbation.

        Args:
            f: discriminator builder, called as ``f(x, cfg=cfg, name=name)``.
                Its first return value is the prediction.
            real: real image variable. The reshape below assumes 4-D input.
            fake: optional fake image variable with the same shape as ``real``.
            cfg: config; ``cfg.enable_ce`` selects a fixed seed of 1.
            name: parameter scope name passed to ``f``.

        Returns:
            A scalar-like variable holding the penalty.
        """

        def _interpolate(a, b=None):
            # seed=0 is the API default (system-chosen seed); a fixed seed
            # keeps continuous-evaluation runs repeatable.
            seed = 1 if cfg.enable_ce else 0
            if b is None:
                beta = fluid.layers.uniform_random_batch_size_like(
                    input=a, shape=a.shape, min=0.0, max=1.0, seed=seed)
                all_axes = list(range(len(a.shape)))
                mean = fluid.layers.reduce_mean(a, dim=all_axes, keep_dim=True)
                centered = fluid.layers.elementwise_sub(a, mean, axis=0)
                variance = fluid.layers.reduce_mean(
                    fluid.layers.square(centered), dim=all_axes, keep_dim=True)
                b = beta * fluid.layers.sqrt(variance) * 0.5 + a
            alpha = fluid.layers.uniform_random_batch_size_like(
                input=a, shape=[a.shape[0]], min=0.0, max=1.0, seed=seed)
            return fluid.layers.elementwise_mul((b - a), alpha, axis=0) + a

        x = _interpolate(real, fake)

        pred, _ = f(x, cfg=cfg, name=name)
        if isinstance(pred, tuple):
            pred = pred[0]
        # Only d pred / d x is needed, so discriminator parameters are
        # excluded from the gradient computation.
        disc_params = [
            var.name for var in fluid.default_main_program().list_vars()
            if fluid.io.is_parameter(var) and
            var.name.startswith("discriminator")
        ]
        grad = fluid.gradients(pred, x, no_grad_set=disc_params)[0]
        grad_shape = grad.shape
        grad = fluid.layers.reshape(
            grad, [-1, grad_shape[1] * grad_shape[2] * grad_shape[3]])
        # epsilon keeps sqrt differentiable when the gradient is zero.
        epsilon = 1e-16
        norm = fluid.layers.sqrt(
            fluid.layers.reduce_sum(
                fluid.layers.square(grad), dim=1) + epsilon)
        gp = fluid.layers.reduce_mean(fluid.layers.square(norm - 1.0))
        return gp


class STGAN(object):
    """STGAN driver: CLI options, graph construction and the training loop."""

    def add_special_args(self, parser):
        """Register the STGAN-specific command-line options on ``parser``.

        Args:
            parser: an ``argparse.ArgumentParser`` (or compatible object).

        Returns:
            The same ``parser`` instance.
        """
        parser.add_argument(
            '--g_lr',
            type=float,
            default=0.0002,
            help="the base learning rate of generator")
        parser.add_argument(
            '--d_lr',
            type=float,
            default=0.0002,
            help="the base learning rate of discriminator")
        parser.add_argument(
            '--c_dim',
            type=int,
            default=13,
            help="the number of attributes we selected")
        parser.add_argument(
            '--d_fc_dim',
            type=int,
            default=1024,
            help="the base fc dim in discriminator")
        parser.add_argument(
            '--use_gru',
            type=ast.literal_eval,
            default=True,
            help="whether to use GRU")
        parser.add_argument(
            '--lambda_cls',
            type=float,
            default=10.0,
            help="the coefficient of classification")
        parser.add_argument(
            '--lambda_rec',
            type=float,
            default=100.0,
            help="the coefficient of refactor")
        parser.add_argument(
            '--thres_int',
            type=float,
            default=0.5,
            help="thresh change of attributes")
        parser.add_argument(
            '--lambda_gp',
            type=float,
            default=10.0,
            help="the coefficient of gradient penalty")
        parser.add_argument(
            '--n_samples', type=int, default=16, help="batch size when testing")
        parser.add_argument(
            '--selected_attrs',
            type=str,
            default="Bald,Bangs,Black_Hair,Blond_Hair,Brown_Hair,Bushy_Eyebrows,Eyeglasses,Male,Mouth_Slightly_Open,Mustache,No_Beard,Pale_Skin,Young",
            help="the attributes we selected to change")
        parser.add_argument(
            '--n_layers',
            type=int,
            default=5,
            help="default layers in generotor")
        parser.add_argument(
            '--gru_n_layers',
            type=int,
            default=4,
            help="default layers of GRU in generotor")
        parser.add_argument(
            '--dis_norm',
            type=str,
            default=None,
            help="the normalization in discriminator, choose in [None, instance_norm]"
        )
        parser.add_argument(
            '--enable_ce',
            action='store_true',
            help="if set, run the tasks with continuous evaluation logs")
        return parser

    def __init__(self,
                 cfg=None,
                 train_reader=None,
                 test_reader=None,
                 batch_num=1,
                 id2name=None):
        """Store the configuration and data sources.

        Args:
            cfg: parsed configuration namespace.
            train_reader: batch generator for training, passed to a PyReader.
            test_reader: batch generator for the per-epoch test pass.
            batch_num: batches per epoch; forwarded to the trainers as
                ``step_per_epoch``.
            id2name: stored as ``self.id2name``. It was previously accepted
                but discarded; this class does not use it otherwise.
        """
        self.cfg = cfg
        self.train_reader = train_reader
        self.test_reader = test_reader
        self.batch_num = batch_num
        self.id2name = id2name

    def build_model(self):
        """Build the STGAN programs and run the training loop.

        Every batch runs one discriminator update. The generator is updated
        on every ``cfg.num_discriminator_time``-th batch.

        After each epoch, the method writes test images if ``cfg.run_test``
        is set, and checkpoints if ``cfg.save_checkpoints`` is set.

        It returns early once ``cfg.max_iter`` batches have run, or after a
        short profiling window in epoch 0 when ``cfg.profile`` is set.
        """
        data_shape = [None, 3, self.cfg.image_size, self.cfg.image_size]

        image_real = fluid.data(
            name='image_real', shape=data_shape, dtype='float32')
        label_org = fluid.data(
            name='label_org', shape=[None, self.cfg.c_dim], dtype='float32')
        label_trg = fluid.data(
            name='label_trg', shape=[None, self.cfg.c_dim], dtype='float32')
        label_org_ = fluid.data(
            name='label_org_', shape=[None, self.cfg.c_dim], dtype='float32')
        label_trg_ = fluid.data(
            name='label_trg_', shape=[None, self.cfg.c_dim], dtype='float32')
        # used for continuous evaluation
        if self.cfg.enable_ce:
            fluid.default_startup_program().random_seed = 90

        # Generator for the test pass. Its label_org_/label_trg_ inputs are
        # the raw data placeholders declared above.
        test_gen_trainer = GTrainer(image_real, label_org, label_org_,
                                    label_trg, label_trg_, self.cfg,
                                    self.batch_num)

        py_reader = fluid.io.PyReader(
            feed_list=[image_real, label_org, label_trg],
            capacity=64,
            iterable=True,
            use_double_buffer=True)
        # For training, the generator labels are derived from the fed labels:
        # x -> (2x - 1) * thres_int, which maps 0/1 to -thres_int/+thres_int.
        label_org_ = (label_org * 2.0 - 1.0) * self.cfg.thres_int
        label_trg_ = (label_trg * 2.0 - 1.0) * self.cfg.thres_int

        gen_trainer = GTrainer(image_real, label_org, label_org_, label_trg,
                               label_trg_, self.cfg, self.batch_num)
        dis_trainer = DTrainer(image_real, label_org, label_org_, label_trg,
                               label_trg_, self.cfg, self.batch_num)

        # prepare environment
        place = fluid.CUDAPlace(0) if self.cfg.use_gpu else fluid.CPUPlace()
        py_reader.decorate_batch_generator(
            self.train_reader,
            places=fluid.cuda_places()
            if self.cfg.use_gpu else fluid.cpu_places())

        exe = fluid.Executor(place)
        exe.run(fluid.default_startup_program())

        if self.cfg.init_model:
            utility.init_checkpoints(self.cfg, exe, gen_trainer, "net_G")
            utility.init_checkpoints(self.cfg, exe, dis_trainer, "net_D")

        # Shared build strategy for both data-parallel programs.
        build_strategy = fluid.BuildStrategy()

        gen_trainer_program = fluid.CompiledProgram(
            gen_trainer.program).with_data_parallel(
                loss_name=gen_trainer.g_loss.name,
                build_strategy=build_strategy)
        dis_trainer_program = fluid.CompiledProgram(
            dis_trainer.program).with_data_parallel(
                loss_name=dis_trainer.d_loss.name,
                build_strategy=build_strategy)
        # used for continuous evaluation
        # NOTE(review): this sets an attribute on the CompiledProgram
        # wrappers, not on the underlying Programs. Confirm it affects op
        # seeds at all.
        if self.cfg.enable_ce:
            gen_trainer_program.random_seed = 90
            dis_trainer_program.random_seed = 90

        # The test reader is the same every epoch, so it is built once here
        # rather than at the end of each epoch. Both trainer programs are
        # already cloned at this point, so they do not gain 'image_name'.
        test_py_reader = None
        if self.cfg.run_test:
            image_name = fluid.data(
                name='image_name',
                shape=[None, self.cfg.n_samples],
                dtype='int32')
            test_py_reader = fluid.io.PyReader(
                feed_list=[image_real, label_org, label_trg, image_name],
                capacity=32,
                iterable=True,
                use_double_buffer=True)
            test_py_reader.decorate_batch_generator(
                self.test_reader,
                places=fluid.cuda_places()
                if self.cfg.use_gpu else fluid.cpu_places())

        total_train_batch = 0  # used for benchmark

        for epoch_id in range(self.cfg.epoch):
            batch_id = 0
            for data in py_reader():
                if self.cfg.max_iter and total_train_batch == self.cfg.max_iter:  # used for benchmark
                    return
                s_time = time.time()
                # optimize the discriminator network
                d_fetches = [
                    dis_trainer.d_loss.name,
                    dis_trainer.d_loss_real.name,
                    dis_trainer.d_loss_fake.name,
                    dis_trainer.d_loss_cls.name,
                    dis_trainer.d_loss_gp.name,
                ]
                d_loss, d_loss_real, d_loss_fake, d_loss_cls, d_loss_gp = exe.run(
                    dis_trainer_program, fetch_list=d_fetches, feed=data)
                if (batch_id + 1) % self.cfg.num_discriminator_time == 0:
                    # optimize the generator network
                    g_fetches = [
                        gen_trainer.g_loss_fake.name,
                        gen_trainer.g_loss_rec.name,
                        gen_trainer.g_loss_cls.name,
                    ]
                    g_loss_fake, g_loss_rec, g_loss_cls = exe.run(
                        gen_trainer_program, fetch_list=g_fetches, feed=data)
                    print("epoch{}: batch{}: \n\
                         g_loss_fake: {}; g_loss_rec: {}; g_loss_cls: {}"
                          .format(epoch_id, batch_id, g_loss_fake[0],
                                  g_loss_rec[0], g_loss_cls[0]))
                batch_time = time.time() - s_time
                if (batch_id + 1) % self.cfg.print_freq == 0:
                    print("epoch{}: batch{}:  \n\
                         d_loss: {}; d_loss_real: {}; d_loss_fake: {}; d_loss_cls: {}; d_loss_gp: {} \n\
                         Batch_time_cost: {}".format(
                             epoch_id, batch_id, d_loss[0], d_loss_real[0],
                             d_loss_fake[0], d_loss_cls[0], d_loss_gp[0],
                             batch_time))
                sys.stdout.flush()
                batch_id += 1
                if self.cfg.enable_ce and batch_id == 100:
                    break

                total_train_batch += 1  # used for benchmark
                # profiler tools
                if self.cfg.profile and epoch_id == 0 and batch_id == self.cfg.print_freq:
                    profiler.reset_profiler()
                elif self.cfg.profile and epoch_id == 0 and batch_id == self.cfg.print_freq + 5:
                    return

            if self.cfg.run_test:
                test_program = test_gen_trainer.infer_program
                utility.save_test_image(epoch_id, self.cfg, exe, place,
                                        test_program, test_gen_trainer,
                                        test_py_reader)

            if self.cfg.save_checkpoints:
                utility.checkpoints(epoch_id, self.cfg, exe, gen_trainer,
                                    "net_G")
                utility.checkpoints(epoch_id, self.cfg, exe, dis_trainer,
                                    "net_D")
            # used for continuous evaluation
            # g_loss_* hold the values from the most recent generator step.
            # They are unbound if no generator step has run yet.
            if self.cfg.enable_ce:
                device_num = fluid.core.get_cuda_device_count() if self.cfg.use_gpu else 1
                print("kpis\tstgan_g_loss_fake_card{}\t{}".format(device_num, g_loss_fake[0]))
                print("kpis\tstgan_g_loss_rec_card{}\t{}".format(device_num, g_loss_rec[0]))
                print("kpis\tstgan_g_loss_cls_card{}\t{}".format(device_num, g_loss_cls[0]))
                print("kpis\tstgan_d_loss_card{}\t{}".format(device_num, d_loss[0]))
                print("kpis\tstgan_d_loss_real_card{}\t{}".format(device_num, d_loss_real[0]))
                print("kpis\tstgan_d_loss_fake_card{}\t{}".format(device_num, d_loss_fake[0]))
                print("kpis\tstgan_d_loss_cls_card{}\t{}".format(device_num, d_loss_cls[0]))
                print("kpis\tstgan_d_loss_gp_card{}\t{}".format(device_num, d_loss_gp[0]))
                print("kpis\tstgan_Batch_time_cost_card{}\t{}".format(device_num, batch_time))