cycle_gan_model.py 9.3 KB
Newer Older
L
LielinJiang 已提交
1 2 3 4 5 6 7
import paddle
from .base_model import BaseModel

from .builder import MODELS
from .generators.builder import build_generator
from .discriminators.builder import build_discriminator
from .losses import GANLoss
8

L
LielinJiang 已提交
9
from ..solver import build_optimizer
L
LielinJiang 已提交
10
from ..modules.init import init_weights
L
LielinJiang 已提交
11 12 13 14 15 16 17 18 19 20 21 22 23 24 25
from ..utils.image_pool import ImagePool


@MODELS.register()
class CycleGANModel(BaseModel):
    """
    This class implements the CycleGAN model, for learning image-to-image translation without paired data.

    The model training requires '--dataset_mode unaligned' dataset.
    By default, it uses a '--netG resnet_9blocks' ResNet generator,
    a '--netD basic' discriminator (PatchGAN introduced by pix2pix),
    and a least-square GANs objective ('--gan_mode lsgan').

    CycleGAN paper: https://arxiv.org/pdf/1703.10593.pdf
    """
L
LielinJiang 已提交
26
    def __init__(self, cfg):
        """Initialize the CycleGAN model.

        Args:
            cfg (config): experiment configuration. The fields read here are
                ``cfg.model.generator``, ``cfg.model.discriminator``,
                ``cfg.model.gan_mode``, ``cfg.optimizer``,
                ``cfg.lambda_identity`` and, under ``cfg.dataset.train``,
                ``input_nc``, ``output_nc`` and ``pool_size``.

        Raises:
            ValueError: in training mode, if the identity loss is enabled
                (``cfg.lambda_identity > 0``) while the configured input and
                output channel counts differ.
        """
        super(CycleGANModel, self).__init__(cfg)

        # define networks (both Generators and discriminators)
        # The naming is different from those used in the paper.
        # Code (vs. paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
        # Both networks are built before either is initialised; keep this
        # order so that seeded runs create their weights in the same sequence.
        self.nets['netG_A'] = build_generator(cfg.model.generator)
        self.nets['netG_B'] = build_generator(cfg.model.generator)
        init_weights(self.nets['netG_A'])
        init_weights(self.nets['netG_B'])

        # Everything below (discriminators, image buffers, losses and
        # optimizers) is only created for training.
        if not self.is_train:
            return

        self.nets['netD_A'] = build_discriminator(cfg.model.discriminator)
        self.nets['netD_B'] = build_discriminator(cfg.model.discriminator)
        init_weights(self.nets['netD_A'])
        init_weights(self.nets['netD_B'])

        train_data_cfg = cfg.dataset.train

        # The identity loss only works when input and output images have the
        # same number of channels. Raise explicitly instead of asserting so
        # the check is not stripped when Python runs with -O.
        if cfg.lambda_identity > 0.0:
            if train_data_cfg.input_nc != train_data_cfg.output_nc:
                raise ValueError(
                    'lambda_identity > 0 requires input_nc == output_nc, '
                    'got input_nc={} and output_nc={}'.format(
                        train_data_cfg.input_nc, train_data_cfg.output_nc))

        # create image buffers to store previously generated images
        self.fake_A_pool = ImagePool(train_data_cfg.pool_size)
        self.fake_B_pool = ImagePool(train_data_cfg.pool_size)

        # define loss functions
        self.criterionGAN = GANLoss(cfg.model.gan_mode)
        self.criterionCycle = paddle.nn.L1Loss()
        self.criterionIdt = paddle.nn.L1Loss()

        self.build_lr_scheduler()

        # One optimizer updates both generators, another both discriminators.
        generator_params = (self.nets['netG_A'].parameters() +
                            self.nets['netG_B'].parameters())
        self.optimizers['optimizer_G'] = build_optimizer(
            cfg.optimizer,
            self.lr_scheduler,
            parameter_list=generator_params)

        discriminator_params = (self.nets['netD_A'].parameters() +
                                self.nets['netD_B'].parameters())
        self.optimizers['optimizer_D'] = build_optimizer(
            cfg.optimizer,
            self.lr_scheduler,
            parameter_list=discriminator_params)
L
LielinJiang 已提交
72 73 74 75

    def set_input(self, input):
        """Store the images and image paths of one dataloader batch on the model.

        Args:
            input (dict): batch produced by the dataloader. The keys read
                here are 'A', 'B', 'A_paths' and 'B_paths'; each of them is
                optional.

        The dataset option 'direction' selects which domain acts as the
        source: with 'AtoB', key 'A' is stored as ``real_A`` and key 'B' as
        ``real_B``; with any other value the two keys are swapped.
        """
        phase = 'test'
        if self.is_train:
            phase = 'train'

        # (source key, target key) from the generators' point of view.
        if self.cfg.dataset[phase].direction == 'AtoB':
            source_key, target_key = 'A', 'B'
        else:
            source_key, target_key = 'B', 'A'

        # A key that is absent leaves the matching attribute untouched.
        if source_key in input:
            self.real_A = paddle.to_tensor(input[source_key])
        if target_key in input:
            self.real_B = paddle.to_tensor(input[target_key])

        # Prefer the paths of domain A and fall back to those of domain B.
        for paths_key in ('A_paths', 'B_paths'):
            if paths_key in input:
                self.image_paths = input[paths_key]
                break
99

L
LielinJiang 已提交
100 101
    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
L
LielinJiang 已提交
102
        if hasattr(self, 'real_A'):
L
LielinJiang 已提交
103 104 105 106 107 108 109
            self.fake_B = self.nets['netG_A'](self.real_A)  # G_A(A)
            self.rec_A = self.nets['netG_B'](self.fake_B)  # G_B(G_A(A))

            # visual
            self.visual_items['real_A'] = self.real_A
            self.visual_items['fake_B'] = self.fake_B
            self.visual_items['rec_A'] = self.rec_A
L
LielinJiang 已提交
110

L
LielinJiang 已提交
111
        if hasattr(self, 'real_B'):
L
LielinJiang 已提交
112 113 114 115 116 117 118
            self.fake_A = self.nets['netG_B'](self.real_B)  # G_B(B)
            self.rec_B = self.nets['netG_A'](self.fake_A)  # G_A(G_B(B))

            # visual
            self.visual_items['real_B'] = self.real_B
            self.visual_items['fake_A'] = self.fake_A
            self.visual_items['rec_B'] = self.rec_B
L
LielinJiang 已提交
119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138

    def backward_D_basic(self, netD, real, fake):
        """Compute the GAN loss of one discriminator and back-propagate it.

        Parameters:
            netD (network)      -- the discriminator D
            real (tensor array) -- real images
            fake (tensor array) -- images generated by a generator

        Returns the discriminator loss, i.e. the mean of the loss on the
        real images and the loss on the generated images.
        ``backward()`` has already been called on it.
        """
        # Real images are scored against the "real" target.
        real_term = self.criterionGAN(netD(real), True)

        # Generated images are scored against the "fake" target. They are
        # detached first, so the backward call below uses a detached copy of
        # the generator output.
        fake_term = self.criterionGAN(netD(fake.detach()), False)

        # Combined loss and calculate gradients
        combined = (real_term + fake_term) * 0.5
        combined.backward()
        return combined

    def backward_D_A(self):
        """Compute and back-propagate the GAN loss of discriminator D_A.

        D_A compares ``real_B`` with a generated B image queried from the
        ``fake_B_pool`` image buffer. The loss is stored as ``loss_D_A`` and
        under ``losses['D_A_loss']``.
        """
        pooled_fake_B = self.fake_B_pool.query(self.fake_B)
        discriminator = self.nets['netD_A']
        self.loss_D_A = self.backward_D_basic(discriminator, self.real_B,
                                              pooled_fake_B)
        self.losses['D_A_loss'] = self.loss_D_A
L
LielinJiang 已提交
149 150 151 152

    def backward_D_B(self):
        """Compute and back-propagate the GAN loss of discriminator D_B.

        D_B compares ``real_A`` with a generated A image queried from the
        ``fake_A_pool`` image buffer. The loss is stored as ``loss_D_B`` and
        under ``losses['D_B_loss']``.
        """
        pooled_fake_A = self.fake_A_pool.query(self.fake_A)
        discriminator = self.nets['netD_B']
        self.loss_D_B = self.backward_D_basic(discriminator, self.real_A,
                                              pooled_fake_A)
        self.losses['D_B_loss'] = self.loss_D_B
L
LielinJiang 已提交
156 157 158

    def backward_G(self):
        """Compute the total loss of generators G_A and G_B and back-propagate it.

        The total is the sum of the two adversarial losses, the two cycle
        losses (weighted by ``cfg.lambda_A`` / ``cfg.lambda_B``) and, when
        ``cfg.lambda_identity`` is positive, the two identity losses. Every
        term is also recorded in ``self.losses``.
        """
        config = self.cfg
        lambda_idt = config.lambda_identity
        lambda_A = config.lambda_A
        lambda_B = config.lambda_B

        # Identity loss
        if not (lambda_idt > 0):
            # Disabled: both identity terms contribute a plain zero.
            self.loss_idt_A = 0
            self.loss_idt_B = 0
        else:
            # G_A should be identity if real_B is fed: ||G_A(B) - B||
            self.idt_A = self.nets['netG_A'](self.real_B)
            idt_A_distance = self.criterionIdt(self.idt_A, self.real_B)
            self.loss_idt_A = idt_A_distance * lambda_B * lambda_idt

            # G_B should be identity if real_A is fed: ||G_B(A) - A||
            self.idt_B = self.nets['netG_B'](self.real_A)

            # visual
            self.visual_items['idt_A'] = self.idt_A
            self.visual_items['idt_B'] = self.idt_B

            idt_B_distance = self.criterionIdt(self.idt_B, self.real_A)
            self.loss_idt_B = idt_B_distance * lambda_A * lambda_idt

        # GAN loss D_A(G_A(A))
        score_fake_B = self.nets['netD_A'](self.fake_B)
        self.loss_G_A = self.criterionGAN(score_fake_B, True)
        # GAN loss D_B(G_B(B))
        score_fake_A = self.nets['netD_B'](self.fake_A)
        self.loss_G_B = self.criterionGAN(score_fake_A, True)

        # Forward cycle loss || G_B(G_A(A)) - A||
        self.loss_cycle_A = self.criterionCycle(self.rec_A,
                                                self.real_A) * lambda_A
        # Backward cycle loss || G_A(G_B(B)) - B||
        self.loss_cycle_B = self.criterionCycle(self.rec_B,
                                                self.real_B) * lambda_B

        # Publish every individual term under its logging key.
        named_terms = (
            ('G_idt_A_loss', self.loss_idt_A),
            ('G_idt_B_loss', self.loss_idt_B),
            ('G_A_adv_loss', self.loss_G_A),
            ('G_B_adv_loss', self.loss_G_B),
            ('G_A_cycle_loss', self.loss_cycle_A),
            ('G_B_cycle_loss', self.loss_cycle_B),
        )
        for loss_name, loss_value in named_terms:
            self.losses[loss_name] = loss_value

        # combined loss and calculate gradients; the terms are added left to
        # right in a fixed order (adversarial, cycle, identity).
        total = self.loss_G_A
        for term in (self.loss_G_B, self.loss_cycle_A, self.loss_cycle_B,
                     self.loss_idt_A, self.loss_idt_B):
            total = total + term
        self.loss_G = total

        self.loss_G.backward()
L
LielinJiang 已提交
205 206 207 208

    def optimize_parameters(self):
        """Run one training iteration: losses, gradients and weight updates.

        The generators are updated first, with the discriminators' gradients
        switched off; the discriminators are updated afterwards.
        """
        # forward: compute fake images and reconstruction images.
        self.forward()

        # ---- G_A and G_B ----
        # Ds require no gradients when optimizing Gs
        self.set_requires_grad([self.nets['netD_A'], self.nets['netD_B']],
                               False)
        generator_optimizer = self.optimizers['optimizer_G']
        # set G_A and G_B's gradients to zero
        generator_optimizer.clear_grad()
        # calculate gradients for G_A and G_B
        self.backward_G()
        # update G_A and G_B's weights
        generator_optimizer.step()

        # ---- D_A and D_B ----
        self.set_requires_grad([self.nets['netD_A'], self.nets['netD_B']],
                               True)
        discriminator_optimizer = self.optimizers['optimizer_D']
        # set D_A and D_B's gradients to zero
        discriminator_optimizer.clear_grad()
        # calculate gradients for D_A
        self.backward_D_A()
        # calculate gradients for D_B
        self.backward_D_B()
        # update D_A and D_B's weights
        discriminator_optimizer.step()