# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from network.STGAN_network import STGAN_model
from util import utility
from util import timer
import paddle.fluid as fluid
from paddle.fluid import profiler
import sys
import time
import ast


class GTrainer():
    def __init__(self, image_real, label_org, label_org_, label_trg,
                 label_trg_, cfg, step_per_epoch):
        self.program = fluid.default_main_program().clone()
        with fluid.program_guard(self.program):
            model = STGAN_model()
            self.fake_img, self.rec_img = model.network_G(
                image_real, label_org_, label_trg_, cfg, name="generator")
            self.infer_program = self.program.clone(for_test=True)
            # L1 reconstruction loss between the input image and the image
            # reconstructed with its original attributes.
            self.g_loss_rec = fluid.layers.mean(
                fluid.layers.abs(
                    fluid.layers.elementwise_sub(
                        x=image_real, y=self.rec_img)))
            self.pred_fake, self.cls_fake = model.network_D(
                self.fake_img, cfg, name="discriminator")
            # wgan
            if cfg.gan_mode == "wgan":
                self.g_loss_fake = -1 * fluid.layers.mean(self.pred_fake)
            # lsgan
            elif cfg.gan_mode == "lsgan":
                fake_shape = fluid.layers.shape(self.pred_fake)
                ones = fluid.layers.fill_constant(
                    shape=fake_shape, value=1.0, dtype='float32')
                self.g_loss_fake = fluid.layers.mean(
                    fluid.layers.square(
                        fluid.layers.elementwise_sub(
                            x=self.pred_fake, y=ones)))
            else:
                raise NotImplementedError(
                    "gan_mode {} is not supported!".format(cfg.gan_mode))
            self.g_loss_cls = fluid.layers.mean(
                fluid.layers.sigmoid_cross_entropy_with_logits(self.cls_fake,
                                                               label_trg))
            self.g_loss = (self.g_loss_fake +
                           cfg.lambda_rec * self.g_loss_rec +
                           cfg.lambda_cls * self.g_loss_cls)
            lr = cfg.g_lr
            vars = []
            for var in self.program.list_vars():
                if fluid.io.is_parameter(var) and var.name.startswith(
                        "generator"):
                    vars.append(var.name)
            self.param = vars
            # Drop the learning rate to 0.1x after 99 epochs.
            optimizer = fluid.optimizer.Adam(
                learning_rate=fluid.layers.piecewise_decay(
                    boundaries=[99 * step_per_epoch], values=[lr, lr * 0.1]),
                beta1=0.5,
                beta2=0.999,
                name="net_G")
            optimizer.minimize(self.g_loss, parameter_list=vars)


class DTrainer():
    def __init__(self, image_real, label_org, label_org_, label_trg,
                 label_trg_, cfg, step_per_epoch):
        self.program = fluid.default_main_program().clone()
        lr = cfg.d_lr
        with fluid.program_guard(self.program):
            model = STGAN_model()
            self.fake_img, _ = model.network_G(
                image_real, label_org_, label_trg_, cfg, name="generator")
            self.pred_real, self.cls_real = model.network_D(
                image_real, cfg, name="discriminator")
            self.pred_fake, _ = model.network_D(
                self.fake_img, cfg, name="discriminator")
            self.d_loss_cls = fluid.layers.mean(
                fluid.layers.sigmoid_cross_entropy_with_logits(self.cls_real,
                                                               label_org))
            # wgan
            if cfg.gan_mode == "wgan":
                self.d_loss_fake = fluid.layers.reduce_mean(self.pred_fake)
                self.d_loss_real = -1 * fluid.layers.reduce_mean(
                    self.pred_real)
                self.d_loss_gp = self.gradient_penalty(
                    model.network_D,
                    image_real,
                    self.fake_img,
                    cfg=cfg,
                    name="discriminator")
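                # Total critic objective assembled below (WGAN-GP form):
                # E[D(fake)] - E[D(real)] plus the attribute-classification
                # loss (weight 1.0) and the gradient penalty
                # (weight cfg.lambda_gp).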
                self.d_loss = self.d_loss_real + self.d_loss_fake + \
                    1.0 * self.d_loss_cls + cfg.lambda_gp * self.d_loss_gp
            # lsgan
            elif cfg.gan_mode == "lsgan":
                real_shape = fluid.layers.shape(self.pred_real)
                ones = fluid.layers.fill_constant(
                    shape=real_shape, value=1.0, dtype='float32')
                self.d_loss_real = fluid.layers.mean(
                    fluid.layers.square(
                        fluid.layers.elementwise_sub(
                            x=self.pred_real, y=ones)))
                self.d_loss_fake = fluid.layers.mean(
                    fluid.layers.square(x=self.pred_fake))
                self.d_loss_gp = self.gradient_penalty(
                    model.network_D,
                    image_real,
                    None,
                    cfg=cfg,
                    name="discriminator")
                self.d_loss = self.d_loss_real + self.d_loss_fake + \
                    1.0 * self.d_loss_cls + cfg.lambda_gp * self.d_loss_gp
            else:
                raise NotImplementedError(
                    "gan_mode {} is not supported!".format(cfg.gan_mode))

            vars = []
            for var in self.program.list_vars():
                if fluid.io.is_parameter(var) and (
                        var.name.startswith("discriminator")):
                    vars.append(var.name)
            self.param = vars
            # Drop the learning rate to 0.1x after 99 epochs.
            optimizer = fluid.optimizer.Adam(
                learning_rate=fluid.layers.piecewise_decay(
                    boundaries=[99 * step_per_epoch], values=[lr, lr * 0.1]),
                beta1=0.5,
                beta2=0.999,
                name="net_D")
            optimizer.minimize(self.d_loss, parameter_list=vars)

    def gradient_penalty(self, f, real, fake=None, cfg=None, name=None):
        def _interpolate(a, b=None):
            a_shape = fluid.layers.shape(a)
            if b is None:
                # No fake batch given: perturb the real batch with uniform
                # noise scaled by half of its standard deviation.
                if cfg.enable_ce:
                    beta = fluid.layers.uniform_random(
                        shape=a_shape, min=0.0, max=1.0, seed=1)
                else:
                    beta = fluid.layers.uniform_random(
                        shape=a_shape, min=0.0, max=1.0)
                mean = fluid.layers.reduce_mean(
                    a, dim=list(range(len(a.shape))))
                input_sub_mean = fluid.layers.elementwise_sub(a, mean, axis=0)
                var = fluid.layers.reduce_mean(
                    fluid.layers.square(input_sub_mean),
                    dim=list(range(len(a.shape))))
                b = beta * fluid.layers.sqrt(var) * 0.5 + a
            # Random per-sample interpolation between a and b.
            if cfg.enable_ce:
                alpha = fluid.layers.uniform_random(
                    shape=a_shape[0], min=0.0, max=1.0, seed=1)
            else:
                alpha = fluid.layers.uniform_random(
                    shape=a_shape[0], min=0.0, max=1.0)
            inner = fluid.layers.elementwise_mul((b - a), alpha, axis=0) + a
            return inner

        x = _interpolate(real, fake)
        pred, _ = f(x, cfg=cfg, name=name)
        if isinstance(pred, tuple):
            pred = pred[0]
        vars = []
        for var in fluid.default_main_program().list_vars():
            if fluid.io.is_parameter(var) and var.name.startswith(
                    "discriminator"):
                vars.append(var.name)
        grad = fluid.gradients(pred, x, no_grad_set=vars)[0]
        grad_shape = grad.shape
        grad = fluid.layers.reshape(
            grad, [-1, grad_shape[1] * grad_shape[2] * grad_shape[3]])
        # epsilon keeps the square root differentiable at zero.
        epsilon = 1e-16
        norm = fluid.layers.sqrt(
            fluid.layers.reduce_sum(
                fluid.layers.square(grad), dim=1) + epsilon)
        gp = fluid.layers.reduce_mean(fluid.layers.square(norm - 1.0))
        return gp
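
# Note on gradient_penalty above: in wgan mode it draws x_hat as a random
# per-sample interpolation between real and generated images (WGAN-GP); in
# lsgan mode no fake batch is passed, so x_hat is instead a noisy perturbation
# of the real batch. In both cases the penalty is
#     gp = E[(||grad_{x_hat} D(x_hat)||_2 - 1)^2],
# which pushes the discriminator towards being approximately 1-Lipschitz.
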
class STGAN(object):
    def add_special_args(self, parser):
        parser.add_argument(
            '--g_lr',
            type=float,
            default=0.0002,
            help="the base learning rate of the generator")
        parser.add_argument(
            '--d_lr',
            type=float,
            default=0.0002,
            help="the base learning rate of the discriminator")
        parser.add_argument(
            '--c_dim',
            type=int,
            default=13,
            help="the number of selected attributes")
        parser.add_argument(
            '--d_fc_dim',
            type=int,
            default=1024,
            help="the base fc dim in the discriminator")
        parser.add_argument(
            '--use_gru',
            type=ast.literal_eval,
            default=True,
            help="whether to use GRU")
        parser.add_argument(
            '--lambda_cls',
            type=float,
            default=10.0,
            help="the coefficient of the classification loss")
        parser.add_argument(
            '--lambda_rec',
            type=float,
            default=100.0,
            help="the coefficient of the reconstruction loss")
        parser.add_argument(
            '--thres_int',
            type=float,
            default=0.5,
            help="the threshold intensity of attribute change")
        parser.add_argument(
            '--lambda_gp',
            type=float,
            default=10.0,
            help="the coefficient of the gradient penalty")
        parser.add_argument(
            '--n_samples',
            type=int,
            default=16,
            help="batch size when testing")
        parser.add_argument(
            '--selected_attrs',
            type=str,
            default="Bald,Bangs,Black_Hair,Blond_Hair,Brown_Hair,Bushy_Eyebrows,Eyeglasses,Male,Mouth_Slightly_Open,Mustache,No_Beard,Pale_Skin,Young",
            help="the attributes selected to change")
        parser.add_argument(
            '--n_layers',
            type=int,
            default=5,
            help="default number of layers in the generator")
        parser.add_argument(
            '--gru_n_layers',
            type=int,
            default=4,
            help="default number of GRU layers in the generator")
        parser.add_argument(
            '--dis_norm',
            type=str,
            default=None,
            help="the normalization used in the discriminator, choose from [None, instance_norm]"
        )
        parser.add_argument(
            '--enable_ce',
            action='store_true',
            help="if set, run the tasks with continuous evaluation logs")
        return parser

    def __init__(self,
                 cfg=None,
                 train_reader=None,
                 test_reader=None,
                 batch_num=1,
                 id2name=None):
        self.cfg = cfg
        self.train_reader = train_reader
        self.test_reader = test_reader
        self.batch_num = batch_num

    def build_model(self):
        data_shape = [None, 3, self.cfg.image_size, self.cfg.image_size]

        image_real = fluid.data(
            name='image_real', shape=data_shape, dtype='float32')
        label_org = fluid.data(
            name='label_org', shape=[None, self.cfg.c_dim], dtype='float32')
        label_trg = fluid.data(
            name='label_trg', shape=[None, self.cfg.c_dim], dtype='float32')
        label_org_ = fluid.data(
            name='label_org_', shape=[None, self.cfg.c_dim], dtype='float32')
        label_trg_ = fluid.data(
            name='label_trg_', shape=[None, self.cfg.c_dim], dtype='float32')
        # used for continuous evaluation
        if self.cfg.enable_ce:
            fluid.default_startup_program().random_seed = 90

        test_gen_trainer = GTrainer(image_real, label_org, label_org_,
                                    label_trg, label_trg_, self.cfg,
                                    self.batch_num)

        loader = fluid.io.DataLoader.from_generator(
            feed_list=[image_real, label_org, label_trg],
            capacity=64,
            iterable=True,
            use_double_buffer=True)

        # Rescale the binary attribute labels from {0, 1} to
        # {-thres_int, +thres_int} before feeding them to the networks.
        label_org_ = (label_org * 2.0 - 1.0) * self.cfg.thres_int
        label_trg_ = (label_trg * 2.0 - 1.0) * self.cfg.thres_int

        gen_trainer = GTrainer(image_real, label_org, label_org_, label_trg,
                               label_trg_, self.cfg, self.batch_num)
        dis_trainer = DTrainer(image_real, label_org, label_org_, label_trg,
                               label_trg_, self.cfg, self.batch_num)

        # prepare environment
        place = fluid.CUDAPlace(0) if self.cfg.use_gpu else fluid.CPUPlace()
        loader.set_batch_generator(
            self.train_reader,
            places=fluid.cuda_places()
            if self.cfg.use_gpu else fluid.cpu_places())
        exe = fluid.Executor(place)
        exe.run(fluid.default_startup_program())

        if self.cfg.init_model:
            utility.init_checkpoints(self.cfg, gen_trainer, "net_G")
            utility.init_checkpoints(self.cfg, dis_trainer, "net_D")

        ### memory optim
        build_strategy = fluid.BuildStrategy()

        gen_trainer_program = fluid.CompiledProgram(
            gen_trainer.program).with_data_parallel(
                loss_name=gen_trainer.g_loss.name,
                build_strategy=build_strategy)
        dis_trainer_program = fluid.CompiledProgram(
            dis_trainer.program).with_data_parallel(
                loss_name=dis_trainer.d_loss.name,
                build_strategy=build_strategy)

        # used for continuous evaluation
        if self.cfg.enable_ce:
            gen_trainer_program.random_seed = 90
            dis_trainer_program.random_seed = 90

        total_train_batch = 0
        # used for benchmark
        reader_cost_averager = timer.TimeAverager()
        batch_cost_averager = timer.TimeAverager()
        for epoch_id in range(self.cfg.epoch):
            batch_id = 0
            batch_start = time.time()
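            # Update schedule: the discriminator is optimized on every batch,
            # the generator once every cfg.num_discriminator_time batches.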
            for data in loader():
                if self.cfg.max_iter and total_train_batch == self.cfg.max_iter:  # used for benchmark
                    return
                reader_cost_averager.record(time.time() - batch_start)

                # optimize the discriminator network
                fetches = [
                    dis_trainer.d_loss.name,
                    dis_trainer.d_loss_real.name,
                    dis_trainer.d_loss_fake.name,
                    dis_trainer.d_loss_cls.name,
                    dis_trainer.d_loss_gp.name,
                ]
                d_loss, d_loss_real, d_loss_fake, d_loss_cls, d_loss_gp = exe.run(
                    dis_trainer_program, fetch_list=fetches, feed=data)

                if (batch_id + 1) % self.cfg.num_discriminator_time == 0:
                    # optimize the generator network
                    g_fetches = [
                        gen_trainer.g_loss_fake.name,
                        gen_trainer.g_loss_rec.name,
                        gen_trainer.g_loss_cls.name
                    ]
                    g_loss_fake, g_loss_rec, g_loss_cls = exe.run(
                        gen_trainer_program, fetch_list=g_fetches, feed=data)
                    print("epoch{}: batch{}: \n\
                         g_loss_fake: {:.5f}; g_loss_rec: {:.5f}; g_loss_cls: {:.5f}"
                          .format(epoch_id, batch_id, g_loss_fake[0],
                                  g_loss_rec[0], g_loss_cls[0]))

                batch_cost_averager.record(
                    time.time() - batch_start,
                    num_samples=self.cfg.batch_size)
                if (batch_id + 1) % self.cfg.print_freq == 0:
                    print("epoch{}: batch{}: \n\
                         d_loss: {:.5f}; d_loss_real: {:.5f}; d_loss_fake: {:.5f}; d_loss_cls: {:.5f}; d_loss_gp: {:.5f} \n\
                         batch_cost: {:.5f} sec, reader_cost: {:.5f} sec, ips: {:.5f} images/sec"
                          .format(epoch_id, batch_id, d_loss[0],
                                  d_loss_real[0], d_loss_fake[0],
                                  d_loss_cls[0], d_loss_gp[0],
                                  batch_cost_averager.get_average(),
                                  reader_cost_averager.get_average(),
                                  batch_cost_averager.get_ips_average()))
                    reader_cost_averager.reset()
                    batch_cost_averager.reset()

                sys.stdout.flush()
                batch_id += 1
                total_train_batch += 1  # used for benchmark
                batch_start = time.time()

                if self.cfg.enable_ce and batch_id == 100:
                    break

                # profiler tools
                if self.cfg.profile and epoch_id == 0 and batch_id == self.cfg.print_freq:
                    profiler.reset_profiler()
                elif self.cfg.profile and epoch_id == 0 and batch_id == self.cfg.print_freq + 5:
                    return

            if self.cfg.run_test:
                image_name = fluid.data(
                    name='image_name',
                    shape=[None, self.cfg.n_samples],
                    dtype='int32')
                test_loader = fluid.io.DataLoader.from_generator(
                    feed_list=[image_real, label_org, label_trg, image_name],
                    capacity=32,
                    iterable=True,
                    use_double_buffer=True)
                test_loader.set_batch_generator(
                    self.test_reader,
                    places=fluid.cuda_places()
                    if self.cfg.use_gpu else fluid.cpu_places())
                test_program = test_gen_trainer.infer_program
                utility.save_test_image(epoch_id, self.cfg, exe, place,
                                        test_program, test_gen_trainer,
                                        test_loader)

            if self.cfg.save_checkpoints:
                utility.checkpoints(epoch_id, self.cfg, gen_trainer, "net_G")
                utility.checkpoints(epoch_id, self.cfg, dis_trainer, "net_D")

        # used for continuous evaluation
        if self.cfg.enable_ce:
            device_num = fluid.core.get_cuda_device_count(
            ) if self.cfg.use_gpu else 1
            print("kpis\tstgan_g_loss_fake_card{}\t{}".format(device_num,
                                                              g_loss_fake[0]))
            print("kpis\tstgan_g_loss_rec_card{}\t{}".format(device_num,
                                                             g_loss_rec[0]))
            print("kpis\tstgan_g_loss_cls_card{}\t{}".format(device_num,
                                                             g_loss_cls[0]))
            print("kpis\tstgan_d_loss_card{}\t{}".format(device_num,
                                                         d_loss[0]))
            print("kpis\tstgan_d_loss_real_card{}\t{}".format(device_num,
                                                              d_loss_real[0]))
            print("kpis\tstgan_d_loss_fake_card{}\t{}".format(device_num,
                                                              d_loss_fake[0]))
            print("kpis\tstgan_d_loss_cls_card{}\t{}".format(device_num,
                                                             d_loss_cls[0]))
            print("kpis\tstgan_d_loss_gp_card{}\t{}".format(device_num,
                                                            d_loss_gp[0]))
            # batch-time KPI: report the averaged batch cost
            print("kpis\tstgan_Batch_time_cost_card{}\t{}".format(
                device_num, batch_cost_averager.get_average()))
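
# Minimal usage sketch (illustrative only; `args` and the reader variables
# are hypothetical names supplied by the surrounding training entry point):
#
#   model = STGAN(cfg=args, train_reader=train_reader,
#                 test_reader=test_reader, batch_num=batch_num)
#   model.build_model()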