import paddle

from .base_model import BaseModel
from .builder import MODELS
from .generators.builder import build_generator
from .discriminators.builder import build_discriminator
from .losses import GANLoss
from ..solver import build_optimizer
from ..utils.image_pool import ImagePool


@MODELS.register()
class CycleGANModel(BaseModel):
    """
    This class implements the CycleGAN model, for learning image-to-image
    translation without paired data.

    The model training requires '--dataset_mode unaligned' dataset.
    By default, it uses a '--netG resnet_9blocks' ResNet generator,
    a '--netD basic' discriminator (PatchGAN introduced by pix2pix),
    and a least-square GANs objective ('--gan_mode lsgan').

    CycleGAN paper: https://arxiv.org/pdf/1703.10593.pdf
    """

    def __init__(self, opt):
        """Initialize the CycleGAN class.

        Builds both generators and, when training, both discriminators,
        the image pools, the loss criteria and the two optimizers.

        Parameters:
            opt (Option class)-- stores all the experiment flags; needs to be
                a subclass of BaseOptions
        """
        BaseModel.__init__(self, opt)

        # Names of the training losses to print out; each name X corresponds
        # to an attribute self.loss_X set during training.
        self.loss_names = [
            'D_A', 'G_A', 'cycle_A', 'idt_A',
            'D_B', 'G_B', 'cycle_B', 'idt_B',
        ]

        # Names of the images to save/display.
        visual_names_A = ['real_A', 'fake_B', 'rec_A']
        visual_names_B = ['real_B', 'fake_A', 'rec_B']
        if self.isTrain and self.opt.lambda_identity > 0.0:
            # If identity loss is used, we also visualize idt_B=G_B(A) and
            # idt_A=G_A(B) (see backward_G for where they are computed).
            visual_names_A.append('idt_B')
            visual_names_B.append('idt_A')
        # Combine visualizations for A and B.
        self.visual_names = visual_names_A + visual_names_B

        # Names of the models to save to / load from disk.
        if self.isTrain:
            self.model_names = ['G_A', 'G_B', 'D_A', 'D_B']
        else:
            # During test time, only the generators are needed.
            self.model_names = ['G_A', 'G_B']

        # Define networks (both generators and discriminators).
        # The naming is different from those used in the paper.
        # Code (vs. paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
        self.netG_A = build_generator(opt.model.generator)
        self.netG_B = build_generator(opt.model.generator)

        if self.isTrain:
            # Define discriminators.
            self.netD_A = build_discriminator(opt.model.discriminator)
            self.netD_B = build_discriminator(opt.model.discriminator)

            if opt.lambda_identity > 0.0:
                # Identity loss only works when input and output images have
                # the same number of channels.
                assert (opt.dataset.train.input_nc ==
                        opt.dataset.train.output_nc), (
                    'lambda_identity > 0 requires input_nc == output_nc')

            # Image buffers that store previously generated images.
            self.fake_A_pool = ImagePool(opt.dataset.train.pool_size)
            self.fake_B_pool = ImagePool(opt.dataset.train.pool_size)

            # Define loss functions.
            # NOTE(review): the two nested-list arguments are presumably the
            # real/fake target labels -- confirm against GANLoss.
            self.criterionGAN = GANLoss(opt.model.gan_mode, [[[[1.0]]]],
                                        [[[[0.0]]]])
            self.criterionCycle = paddle.nn.L1Loss()
            self.criterionIdt = paddle.nn.L1Loss()

            # One optimizer for both generators, one for both discriminators.
            self.optimizer_G = build_optimizer(
                opt.optimizer,
                parameter_list=(self.netG_A.parameters() +
                                self.netG_B.parameters()))
            self.optimizer_D = build_optimizer(
                opt.optimizer,
                parameter_list=(self.netD_A.parameters() +
                                self.netD_B.parameters()))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)

            self.optimizer_names.extend(['optimizer_G', 'optimizer_D'])

    def set_input(self, input):
        """Unpack input data from the dataloader and perform necessary
        pre-processing steps.

        Parameters:
            input (dict): include the data itself and its metadata information.

        The option 'direction' can be used to swap domain A and domain B.
        Keys that are missing from ``input`` leave the corresponding
        attribute untouched.
        """
        mode = 'train' if self.isTrain else 'test'
        AtoB = self.opt.dataset[mode].direction == 'AtoB'

        # With direction 'AtoB' the dict keys map straight onto the domains;
        # otherwise the two domains are swapped.
        key_A, key_B = ('A', 'B') if AtoB else ('B', 'A')
        if key_A in input:
            self.real_A = paddle.imperative.to_variable(input[key_A])
        if key_B in input:
            self.real_B = paddle.imperative.to_variable(input[key_B])

        # Prefer 'A_paths' when present, regardless of direction.
        if 'A_paths' in input:
            self.image_paths = input['A_paths']
        elif 'B_paths' in input:
            self.image_paths = input['B_paths']

    def forward(self):
        """Run forward pass for every domain that has an input set.

        Called by optimize_parameters; each half is skipped when the
        corresponding real image attribute has never been set.
        """
        if hasattr(self, 'real_A'):
            self.fake_B = self.netG_A(self.real_A)  # G_A(A)
            self.rec_A = self.netG_B(self.fake_B)  # G_B(G_A(A))

        if hasattr(self, 'real_B'):
            self.fake_A = self.netG_B(self.real_B)  # G_B(B)
            self.rec_B = self.netG_A(self.fake_A)  # G_A(G_B(B))

    def backward_D_basic(self, netD, real, fake):
        """Calculate GAN loss for the discriminator

        Parameters:
            netD (network)      -- the discriminator D
            real (tensor array) -- real images
            fake (tensor array) -- images generated by a generator

        Return the discriminator loss.
        We also call loss_D.backward() to calculate the gradients.
        """
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake; detach so no gradient flows back into the generator.
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss and calculate gradients
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        """Calculate GAN loss for discriminator D_A"""
        fake_B = self.fake_B_pool.query(self.fake_B)
        self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)

    def backward_D_B(self):
        """Calculate GAN loss for discriminator D_B"""
        fake_A = self.fake_A_pool.query(self.fake_A)
        self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)

    def backward_G(self):
        """Calculate the loss for generators G_A and G_B

        Stores each term as self.loss_* and calls backward() on the sum,
        self.loss_G.
        """
        lambda_idt = self.opt.lambda_identity
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B

        # Identity loss
        if lambda_idt > 0:
            # G_A should be identity if real_B is fed: ||G_A(B) - B||
            self.idt_A = self.netG_A(self.real_B)
            self.loss_idt_A = self.criterionIdt(
                self.idt_A, self.real_B) * lambda_B * lambda_idt
            # G_B should be identity if real_A is fed: ||G_B(A) - A||
            self.idt_B = self.netG_B(self.real_A)
            self.loss_idt_B = self.criterionIdt(
                self.idt_B, self.real_A) * lambda_A * lambda_idt
        else:
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        # GAN loss D_A(G_A(A))
        self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_B), True)
        # GAN loss D_B(G_B(B))
        self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)
        # Forward cycle loss || G_B(G_A(A)) - A||
        self.loss_cycle_A = self.criterionCycle(self.rec_A,
                                                self.real_A) * lambda_A
        # Backward cycle loss || G_A(G_B(B)) - B||
        self.loss_cycle_B = self.criterionCycle(self.rec_B,
                                                self.real_B) * lambda_B
        # Combined loss and calculate gradients
        self.loss_G = (self.loss_G_A + self.loss_G_B + self.loss_cycle_A +
                       self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B)
        self.loss_G.backward()

    def optimize_parameters(self):
        """Calculate losses, gradients, and update network weights; called in
        every training iteration"""
        # Forward: compute fake images and reconstruction images.
        self.forward()

        # G_A and G_B
        # Ds require no gradients when optimizing Gs.
        self.set_requires_grad([self.netD_A, self.netD_B], False)
        # Set G_A and G_B's gradients to zero.
        self.optimizer_G.clear_gradients()
        # Calculate gradients for G_A and G_B.
        self.backward_G()
        # Update G_A and G_B's weights.
        self.optimizer_G.minimize(self.loss_G)

        # D_A and D_B
        self.set_requires_grad([self.netD_A, self.netD_B], True)
        # Set D_A and D_B's gradients to zero.
        self.optimizer_D.clear_gradients()
        # Calculate gradients for D_A.
        self.backward_D_A()
        # Calculate gradients for D_B.
        self.backward_D_B()
        # Update D_A and D_B's weights.
        self.optimizer_D.minimize(self.loss_D_A + self.loss_D_B)