#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from .base_model import BaseModel
from .builder import MODELS
from .generators.builder import build_generator
from .discriminators.builder import build_discriminator
from .losses import GANLoss

from ..solver import build_optimizer
from ..modules.init import init_weights
from ..utils.image_pool import ImagePool


@MODELS.register()
class CycleGANModel(BaseModel):
    """
    This class implements the CycleGAN model, for learning image-to-image translation without paired data.

    The model training requires an '--dataset_mode unaligned' dataset.
    By default, it uses a '--netG resnet_9blocks' ResNet generator,
    a '--netD basic' discriminator (the PatchGAN introduced by pix2pix),
    and a least-squares GAN objective ('--gan_mode lsgan').

    CycleGAN paper: https://arxiv.org/pdf/1703.10593.pdf
    """
    def __init__(self, cfg):
        """Initialize the CycleGAN class.

        Parameters:
44
            opt (config)-- stores all the experiment flags; needs to be a subclass of Dict
        """
        super(CycleGANModel, self).__init__(cfg)

        # define networks (both Generators and discriminators)
        # The naming is different from that used in the paper.
        # Code (vs. paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
        self.nets['netG_A'] = build_generator(cfg.model.generator)
        self.nets['netG_B'] = build_generator(cfg.model.generator)
        init_weights(self.nets['netG_A'])
        init_weights(self.nets['netG_B'])

        if self.is_train:  # define discriminators
            self.nets['netD_A'] = build_discriminator(cfg.model.discriminator)
            self.nets['netD_B'] = build_discriminator(cfg.model.discriminator)
            init_weights(self.nets['netD_A'])
            init_weights(self.nets['netD_B'])

        if self.is_train:
            if cfg.lambda_identity > 0.0:  # only works when input and output images have the same number of channels
                assert (
                    cfg.dataset.train.input_nc == cfg.dataset.train.output_nc)
            # create image buffer to store previously generated (fake A) images
            self.fake_A_pool = ImagePool(cfg.dataset.train.pool_size)
            # create image buffer to store previously generated (fake B) images
            self.fake_B_pool = ImagePool(cfg.dataset.train.pool_size)
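            # (Updating the discriminators from a history of generated images,
            # rather than only the latest ones, reduces model oscillation; see
            # Shrivastava et al. 2017, as adopted in the CycleGAN paper.)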
            # define loss functions
            self.criterionGAN = GANLoss(cfg.model.gan_mode)
            self.criterionCycle = paddle.nn.L1Loss()
            self.criterionIdt = paddle.nn.L1Loss()

            self.build_lr_scheduler()
            self.optimizers['optimizer_G'] = build_optimizer(
                cfg.optimizer,
                self.lr_scheduler,
                parameter_list=self.nets['netG_A'].parameters() +
                self.nets['netG_B'].parameters())
            self.optimizers['optimizer_D'] = build_optimizer(
                cfg.optimizer,
                self.lr_scheduler,
                parameter_list=self.nets['netD_A'].parameters() +
                self.nets['netD_B'].parameters())
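
            # Note: each optimizer jointly updates both generators (resp. both
            # discriminators), since their parameter lists are concatenated
            # above; both optimizers also share the same lr_scheduler.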

    def set_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        Args:
            input (dict): includes the data itself and its metadata information.

        The option 'direction' can be used to swap domain A and domain B.
        """
        mode = 'train' if self.is_train else 'test'
        AtoB = self.cfg.dataset[mode].direction == 'AtoB'

        if AtoB:
            if 'A' in input:
                self.real_A = paddle.to_tensor(input['A'])
            if 'B' in input:
                self.real_B = paddle.to_tensor(input['B'])
        else:
            if 'B' in input:
                self.real_A = paddle.to_tensor(input['B'])
            if 'A' in input:
                self.real_B = paddle.to_tensor(input['A'])

        if 'A_paths' in input:
            self.image_paths = input['A_paths']
        elif 'B_paths' in input:
            self.image_paths = input['B_paths']

    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
        if hasattr(self, 'real_A'):
            self.fake_B = self.nets['netG_A'](self.real_A)  # G_A(A)
            self.rec_A = self.nets['netG_B'](self.fake_B)  # G_B(G_A(A))

            # visual
            self.visual_items['real_A'] = self.real_A
            self.visual_items['fake_B'] = self.fake_B
            self.visual_items['rec_A'] = self.rec_A

        if hasattr(self, 'real_B'):
            self.fake_A = self.nets['netG_B'](self.real_B)  # G_B(B)
            self.rec_B = self.nets['netG_A'](self.fake_A)  # G_A(G_B(B))

            # visual
            self.visual_items['real_B'] = self.real_B
            self.visual_items['fake_A'] = self.fake_A
            self.visual_items['rec_B'] = self.rec_B

    def backward_D_basic(self, netD, real, fake):
        """Calculate GAN loss for the discriminator

        Parameters:
            netD (network)      -- the discriminator D
            real (tensor array) -- real images
            fake (tensor array) -- images generated by a generator

        Return the discriminator loss.
        We also call loss_D.backward() to calculate the gradients.
        """
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss and calculate gradients
        loss_D = (loss_D_real + loss_D_fake) * 0.5
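        # The 0.5 factor halves the rate at which D learns relative to G,
        # following the original pix2pix/CycleGAN implementation.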
        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        """Calculate GAN loss for discriminator D_A"""
        fake_B = self.fake_B_pool.query(self.fake_B)
        self.loss_D_A = self.backward_D_basic(self.nets['netD_A'], self.real_B,
                                              fake_B)
        self.losses['D_A_loss'] = self.loss_D_A

    def backward_D_B(self):
        """Calculate GAN loss for discriminator D_B"""
        fake_A = self.fake_A_pool.query(self.fake_A)
        self.loss_D_B = self.backward_D_basic(self.nets['netD_B'], self.real_A,
                                              fake_A)
        self.losses['D_B_loss'] = self.loss_D_B

    def backward_G(self):
        """Calculate the loss for generators G_A and G_B"""
        lambda_idt = self.cfg.lambda_identity
        lambda_A = self.cfg.lambda_A
        lambda_B = self.cfg.lambda_B
        # Identity loss
        if lambda_idt > 0:
            # G_A should be identity if real_B is fed: ||G_A(B) - B||
            self.idt_A = self.nets['netG_A'](self.real_B)
            self.loss_idt_A = self.criterionIdt(
                self.idt_A, self.real_B) * lambda_B * lambda_idt
            # G_B should be identity if real_A is fed: ||G_B(A) - A||
            self.idt_B = self.nets['netG_B'](self.real_A)

            # visual
            self.visual_items['idt_A'] = self.idt_A
            self.visual_items['idt_B'] = self.idt_B

            self.loss_idt_B = self.criterionIdt(
                self.idt_B, self.real_A) * lambda_A * lambda_idt
        else:
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        # GAN loss D_A(G_A(A))
        self.loss_G_A = self.criterionGAN(self.nets['netD_A'](self.fake_B),
                                          True)
        # GAN loss D_B(G_B(B))
        self.loss_G_B = self.criterionGAN(self.nets['netD_B'](self.fake_A),
                                          True)
        # Forward cycle loss || G_B(G_A(A)) - A||
        self.loss_cycle_A = self.criterionCycle(self.rec_A,
                                                self.real_A) * lambda_A
        # Backward cycle loss || G_A(G_B(B)) - B||
        self.loss_cycle_B = self.criterionCycle(self.rec_B,
                                                self.real_B) * lambda_B

        self.losses['G_idt_A_loss'] = self.loss_idt_A
        self.losses['G_idt_B_loss'] = self.loss_idt_B
        self.losses['G_A_adv_loss'] = self.loss_G_A
        self.losses['G_B_adv_loss'] = self.loss_G_B
        self.losses['G_A_cycle_loss'] = self.loss_cycle_A
        self.losses['G_B_cycle_loss'] = self.loss_cycle_B
        # combined loss and calculate gradients
        self.loss_G = (self.loss_G_A + self.loss_G_B + self.loss_cycle_A +
                       self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B)
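        # i.e. L_G = L_GAN(G_A) + L_GAN(G_B)
        #            + lambda_A * ||G_B(G_A(A)) - A||_1
        #            + lambda_B * ||G_A(G_B(B)) - B||_1
        #            + the identity terms (already scaled by lambda_idt above)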
        self.loss_G.backward()

    def optimize_parameters(self):
        """Calculate losses, gradients, and update network weights; called in every training iteration"""
        # forward: compute fake images and reconstruction images
        self.forward()
        # G_A and G_B
        # Ds require no gradients when optimizing Gs
        self.set_requires_grad([self.nets['netD_A'], self.nets['netD_B']],
                               False)
        # set G_A and G_B's gradients to zero
        self.optimizers['optimizer_G'].clear_grad()
        # calculate gradients for G_A and G_B
        self.backward_G()
        # update G_A and G_B's weights
        self.optimizers['optimizer_G'].step()
        # D_A and D_B
        self.set_requires_grad([self.nets['netD_A'], self.nets['netD_B']],
                               True)

        # set D_A and D_B's gradients to zero
        self.optimizers['optimizer_D'].clear_grad()
        # calculate gradients for D_A
        self.backward_D_A()
        # calculate gradients for D_B
        self.backward_D_B()
        # update D_A and D_B's weights
        self.optimizers['optimizer_D'].step()
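

# For reference, with gan_mode 'lsgan' the criterionGAN above amounts to a
# mean-squared-error loss against constant labels. A minimal sketch (not the
# actual GANLoss implementation in .losses):
#
#     import paddle
#     import paddle.nn.functional as F
#
#     def lsgan_loss(prediction, target_is_real):
#         target = (paddle.ones_like(prediction) if target_is_real
#                   else paddle.zeros_like(prediction))
#         return F.mse_loss(prediction, target)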