#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from .base_model import BaseModel

from .builder import MODELS
from .generators.builder import build_generator
from .discriminators.builder import build_discriminator
from .criterions import build_criterion

from ..modules.init import init_weights
from ..utils.image_pool import ImagePool


@MODELS.register()
class CycleGANModel(BaseModel):
    """
    This class implements the CycleGAN model, for learning image-to-image translation without paired data.

    CycleGAN paper: https://arxiv.org/pdf/1703.10593.pdf
    """
    def __init__(self,
                 generator,
                 discriminator=None,
                 cycle_criterion=None,
                 idt_criterion=None,
                 gan_criterion=None,
                 pool_size=50,
                 direction='a2b',
                 lambda_a=10.,
                 lambda_b=10.):
        """Initialize the CycleGAN class.

        Args:
            generator (dict): config of the generator.
            discriminator (dict): config of the discriminator.
            cycle_criterion (dict): config of the cycle-consistency criterion.
            idt_criterion (dict): config of the identity criterion.
            gan_criterion (dict): config of the adversarial (GAN) criterion.
            pool_size (int): size of the image buffer that stores previously
                generated images.
            direction (str): direction of the translation, 'a2b' or 'b2a'.
            lambda_a (float): weight of the forward cycle loss (A -> B -> A).
            lambda_b (float): weight of the backward cycle loss (B -> A -> B).
        """
        super(CycleGANModel, self).__init__()

        self.direction = direction

        self.lambda_a = lambda_a
        self.lambda_b = lambda_b
        # define generators
        # The naming is different from that used in the paper.
        # Code (vs. paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
        self.nets['netG_A'] = build_generator(generator)
        self.nets['netG_B'] = build_generator(generator)
        init_weights(self.nets['netG_A'])
        init_weights(self.nets['netG_B'])

        # define discriminators
        if discriminator:
            self.nets['netD_A'] = build_discriminator(discriminator)
            self.nets['netD_B'] = build_discriminator(discriminator)
            init_weights(self.nets['netD_A'])
            init_weights(self.nets['netD_B'])

        # create image buffer to store previously generated images
        self.fake_A_pool = ImagePool(pool_size)
        # create image buffer to store previously generated images
        self.fake_B_pool = ImagePool(pool_size)
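        # query() on these pools returns a mix of the current fakes and
        # previously generated ones, so each discriminator is trained on a
        # history of generated images rather than only the latest batch.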

        # define loss functions
        if gan_criterion:
            self.gan_criterion = build_criterion(gan_criterion)

        if cycle_criterion:
            self.cycle_criterion = build_criterion(cycle_criterion)

        if idt_criterion:
            self.idt_criterion = build_criterion(idt_criterion)

    def setup_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        Args:
            input (dict): includes the data itself and its metadata information.

        The option 'direction' can be used to swap domain A and domain B.
        """

        AtoB = self.direction == 'a2b'

        if AtoB:
            if 'A' in input:
                self.real_A = paddle.to_tensor(input['A'])
            if 'B' in input:
                self.real_B = paddle.to_tensor(input['B'])
        else:
            if 'B' in input:
                self.real_A = paddle.to_tensor(input['B'])
            if 'A' in input:
                self.real_B = paddle.to_tensor(input['A'])

        if 'A_paths' in input:
            self.image_paths = input['A_paths']
        elif 'B_paths' in input:
            self.image_paths = input['B_paths']

    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
        if hasattr(self, 'real_A'):
            self.fake_B = self.nets['netG_A'](self.real_A)  # G_A(A)
            self.rec_A = self.nets['netG_B'](self.fake_B)  # G_B(G_A(A))

            # visual
            self.visual_items['real_A'] = self.real_A
            self.visual_items['fake_B'] = self.fake_B
            self.visual_items['rec_A'] = self.rec_A

        if hasattr(self, 'real_B'):
            self.fake_A = self.nets['netG_B'](self.real_B)  # G_B(B)
            self.rec_B = self.nets['netG_A'](self.fake_A)  # G_A(G_B(B))

            # visual
            self.visual_items['real_B'] = self.real_B
            self.visual_items['fake_A'] = self.fake_A
            self.visual_items['rec_B'] = self.rec_B

    def backward_D_basic(self, netD, real, fake):
        """Calculate GAN loss for the discriminator

        Args:
            netD (Layer): the discriminator D
            real (paddle.Tensor): real images
            fake (paddle.Tensor): images generated by a generator

        Returns:
            the discriminator loss.

        We also call loss_D.backward() to calculate the gradients.
        """
        # Real
        pred_real = netD(real)
        loss_D_real = self.gan_criterion(pred_real, True)
        # Fake
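        # detach() blocks gradients from flowing back into the generator here,
        # so this loss only updates the discriminator parameters.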
        pred_fake = netD(fake.detach())
        loss_D_fake = self.gan_criterion(pred_fake, False)
        # Combined loss and calculate gradients
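        # the 0.5 factor slows down the rate at which D learns relative to G,
        # following the original pix2pix/CycleGAN implementation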
        loss_D = (loss_D_real + loss_D_fake) * 0.5

        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        """Calculate GAN loss for discriminator D_A"""
        fake_B = self.fake_B_pool.query(self.fake_B)
        self.loss_D_A = self.backward_D_basic(self.nets['netD_A'], self.real_B,
                                              fake_B)
        self.losses['D_A_loss'] = self.loss_D_A

    def backward_D_B(self):
        """Calculate GAN loss for discriminator D_B"""
        fake_A = self.fake_A_pool.query(self.fake_A)
        self.loss_D_B = self.backward_D_basic(self.nets['netD_B'], self.real_A,
                                              fake_A)
        self.losses['D_B_loss'] = self.loss_D_B

    def backward_G(self):
        """Calculate the loss for generators G_A and G_B"""
        # Identity loss
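        # The identity terms are optional; the CycleGAN paper notes they help
        # preserve the color of the inputs (e.g. for painting <-> photo
        # translation).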
        if self.idt_criterion:
            # G_A should be identity if real_B is fed: ||G_A(B) - B||
            self.idt_A = self.nets['netG_A'](self.real_B)

            self.loss_idt_A = self.idt_criterion(self.idt_A,
                                                 self.real_B) * self.lambda_b
            # G_B should be identity if real_A is fed: ||G_B(A) - A||
            self.idt_B = self.nets['netG_B'](self.real_A)

            # visual
            self.visual_items['idt_A'] = self.idt_A
            self.visual_items['idt_B'] = self.idt_B

            self.loss_idt_B = self.idt_criterion(self.idt_B,
                                                 self.real_A) * self.lambda_a
        else:
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        # GAN loss D_A(G_A(A))
        self.loss_G_A = self.gan_criterion(self.nets['netD_A'](self.fake_B),
                                           True)
        # GAN loss D_B(G_B(B))
        self.loss_G_B = self.gan_criterion(self.nets['netD_B'](self.fake_A),
                                           True)
        # Forward cycle loss || G_B(G_A(A)) - A||
        self.loss_cycle_A = self.cycle_criterion(self.rec_A,
                                                 self.real_A) * self.lambda_a
        # Backward cycle loss || G_A(G_B(B)) - B||
        self.loss_cycle_B = self.cycle_criterion(self.rec_B,
                                                 self.real_B) * self.lambda_b

        self.losses['G_idt_A_loss'] = self.loss_idt_A
        self.losses['G_idt_B_loss'] = self.loss_idt_B
        self.losses['G_A_adv_loss'] = self.loss_G_A
        self.losses['G_B_adv_loss'] = self.loss_G_B
        self.losses['G_A_cycle_loss'] = self.loss_cycle_A
        self.losses['G_B_cycle_loss'] = self.loss_cycle_B
        # combined loss and calculate gradients
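        # full generator objective: adversarial terms for both mappings plus the
        # forward/backward cycle-consistency terms (weighted by lambda_a and
        # lambda_b) and the optional identity terms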
        self.loss_G = (self.loss_G_A + self.loss_G_B + self.loss_cycle_A +
                       self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B)

        self.loss_G.backward()

    def train_iter(self, optimizers=None):
        """Calculate losses, gradients, and update network weights; called in every training iteration"""
        # forward
        # compute fake images and reconstruction images.
        self.forward()
        # G_A and G_B
        # Ds require no gradients when optimizing Gs
        self.set_requires_grad([self.nets['netD_A'], self.nets['netD_B']],
                               False)
        # set G_A and G_B's gradients to zero
        optimizers['optimG'].clear_grad()
        # calculate gradients for G_A and G_B
        self.backward_G()
        # update G_A and G_B's weights
        optimizers['optimG'].step()
        # D_A and D_B
        self.set_requires_grad([self.nets['netD_A'], self.nets['netD_B']],
                               True)

        # set D_A and D_B's gradients to zero
        optimizers['optimD'].clear_grad()
        # calculate gradients for D_A
        self.backward_D_A()
        # calculate gradients for D_B
        self.backward_D_B()
        # update D_A and D_B's weights
        optimizers['optimD'].step()

    def test_iter(self, metrics=None):
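        """Run netG_A in eval mode for evaluation.

        Computes fake and reconstructed images via forward() and, if metrics
        are provided, updates each metric with the (fake_B, real_B) pair
        before restoring train mode.
        """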
        self.nets['netG_A'].eval()
        self.forward()
        with paddle.no_grad():
            if metrics is not None:
                for metric in metrics.values():
                    metric.update(self.fake_B, self.real_B)
        self.nets['netG_A'].train()