cycle_gan_model.py 10.0 KB
Newer Older
Q
qingqing01 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

L
LielinJiang 已提交
15
import paddle
16
from .base_model import BaseModel, apply_to_static
L
LielinJiang 已提交
17 18 19 20

from .builder import MODELS
from .generators.builder import build_generator
from .discriminators.builder import build_discriminator
21
from .criterions import build_criterion
22

L
LielinJiang 已提交
23
from ..modules.init import init_weights
L
LielinJiang 已提交
24 25 26 27 28 29 30 31 32 33
from ..utils.image_pool import ImagePool


@MODELS.register()
class CycleGANModel(BaseModel):
    """CycleGAN model for image-to-image translation without paired data.

    The model holds two generators and, when a discriminator config is
    given, two discriminators. All networks are stored in ``self.nets``:

    * ``netG_A``: maps ``real_A`` to ``fake_B``.
    * ``netG_B``: maps ``real_B`` to ``fake_A``.
    * ``netD_A``: scores ``real_B`` against ``fake_B``.
    * ``netD_B``: scores ``real_A`` against ``fake_A``.

    The naming is different from the one used in the paper.
    Code (vs. paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)

    CycleGAN paper: https://arxiv.org/pdf/1703.10593.pdf
    """
    def __init__(self,
                 generator,
                 discriminator=None,
                 cycle_criterion=None,
                 idt_criterion=None,
                 gan_criterion=None,
                 pool_size=50,
                 direction='a2b',
                 lambda_a=10.,
                 lambda_b=10.,
                 to_static=False,
                 image_shape=None):
        """Initialize the CycleGAN class.

        Args:
            generator (dict): config of generator; used to build both
                ``netG_A`` and ``netG_B``.
            discriminator (dict): config of discriminator; used to build both
                ``netD_A`` and ``netD_B``. No discriminator is built when it
                is empty or None.
            cycle_criterion (dict): config of cycle criterion.
            idt_criterion (dict): config of identity criterion. The identity
                loss is skipped when it is empty or None.
            gan_criterion (dict): config of GAN (adversarial) criterion.
            pool_size (int): size passed to the two ``ImagePool`` buffers that
                store previously generated images.
            direction (str): ``'a2b'`` feeds ``input['A']`` as ``real_A`` and
                ``input['B']`` as ``real_B``; any other value swaps them.
            lambda_a (float): weight applied to the cycle loss of A and to the
                identity loss of ``netG_B``.
            lambda_b (float): weight applied to the cycle loss of B and to the
                identity loss of ``netG_A``.
            to_static (bool): forwarded to ``apply_to_static`` for every
                network (used for benchmark, skipped by default).
            image_shape (list): forwarded to ``apply_to_static`` together with
                ``to_static``.
        """
        super(CycleGANModel, self).__init__()

        self.direction = direction

        self.lambda_a = lambda_a
        self.lambda_b = lambda_b

        # define generators
        # The naming is different from those used in the paper.
        # Code (vs. paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
        self.nets['netG_A'] = build_generator(generator)
        self.nets['netG_B'] = build_generator(generator)
        # set @to_static for benchmark, skip this by default.
        apply_to_static(to_static, image_shape, self.nets['netG_A'])
        apply_to_static(to_static, image_shape, self.nets['netG_B'])
        init_weights(self.nets['netG_A'])
        init_weights(self.nets['netG_B'])

        # define discriminators
        if discriminator:
            self.nets['netD_A'] = build_discriminator(discriminator)
            self.nets['netD_B'] = build_discriminator(discriminator)
            # set @to_static for benchmark, skip this by default.
            apply_to_static(to_static, image_shape, self.nets['netD_A'])
            apply_to_static(to_static, image_shape, self.nets['netD_B'])
            init_weights(self.nets['netD_A'])
            init_weights(self.nets['netD_B'])

        # create image buffers to store previously generated images
        self.fake_A_pool = ImagePool(pool_size)
        self.fake_B_pool = ImagePool(pool_size)

        # define loss functions
        # The attributes are always assigned (None when not configured) so
        # that later checks such as the identity-loss switch in backward_G
        # never hit an AttributeError.
        self.gan_criterion = (build_criterion(gan_criterion)
                              if gan_criterion else None)
        self.cycle_criterion = (build_criterion(cycle_criterion)
                                if cycle_criterion else None)
        self.idt_criterion = (build_criterion(idt_criterion)
                              if idt_criterion else None)

    def setup_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        Args:
            input (dict): include the data itself and its metadata information.
                Keys read here: 'A', 'B', 'A_paths', 'B_paths' (all optional).

        The option 'direction' can be used to swap domain A and domain B.
        """
        # Pick which input key feeds real_A / real_B for the chosen direction.
        if self.direction == 'a2b':
            key_for_a, key_for_b = 'A', 'B'
        else:
            key_for_a, key_for_b = 'B', 'A'

        if key_for_a in input:
            self.real_A = paddle.to_tensor(input[key_for_a])
        if key_for_b in input:
            self.real_B = paddle.to_tensor(input[key_for_b])

        # 'A_paths' takes precedence over 'B_paths' regardless of direction.
        if 'A_paths' in input:
            self.image_paths = input['A_paths']
        elif 'B_paths' in input:
            self.image_paths = input['B_paths']

    def forward(self):
        """Run forward pass; called by both <train_iter> and <test_iter>.

        Each half only runs when the matching real image has been set by
        <setup_input>; results are also stored in ``self.visual_items``.
        """
        if hasattr(self, 'real_A'):
            self.fake_B = self.nets['netG_A'](self.real_A)  # G_A(A)
            self.rec_A = self.nets['netG_B'](self.fake_B)  # G_B(G_A(A))

            # visual
            self.visual_items['real_A'] = self.real_A
            self.visual_items['fake_B'] = self.fake_B
            self.visual_items['rec_A'] = self.rec_A

        if hasattr(self, 'real_B'):
            self.fake_A = self.nets['netG_B'](self.real_B)  # G_B(B)
            self.rec_B = self.nets['netG_A'](self.fake_A)  # G_A(G_B(B))

            # visual
            self.visual_items['real_B'] = self.real_B
            self.visual_items['fake_A'] = self.fake_A
            self.visual_items['rec_B'] = self.rec_B

    def backward_D_basic(self, netD, real, fake):
        """Calculate GAN loss for the discriminator

        Args:
            netD (Layer): the discriminator D
            real (paddle.Tensor): real images
            fake (paddle.Tensor): images generated by a generator

        Return:
            the discriminator loss.

        We also call loss_D.backward() to calculate the gradients.
        """
        # Real
        pred_real = netD(real)
        loss_D_real = self.gan_criterion(pred_real, True)
        # Fake; detach so no gradient flows back into the generator
        pred_fake = netD(fake.detach())
        loss_D_fake = self.gan_criterion(pred_fake, False)
        # Combined loss and calculate gradients
        loss_D = (loss_D_real + loss_D_fake) * 0.5

        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        """Calculate GAN loss for discriminator D_A"""
        # The fake is drawn from the image pool rather than used directly.
        fake_B = self.fake_B_pool.query(self.fake_B)
        self.loss_D_A = self.backward_D_basic(self.nets['netD_A'], self.real_B,
                                              fake_B)
        self.losses['D_A_loss'] = self.loss_D_A

    def backward_D_B(self):
        """Calculate GAN loss for discriminator D_B"""
        # The fake is drawn from the image pool rather than used directly.
        fake_A = self.fake_A_pool.query(self.fake_A)
        self.loss_D_B = self.backward_D_basic(self.nets['netD_B'], self.real_A,
                                              fake_A)
        self.losses['D_B_loss'] = self.loss_D_B

    def backward_G(self):
        """Calculate the loss for generators G_A and G_B

        The total is the sum of both adversarial losses, both cycle losses
        and (when an identity criterion is configured) both identity losses.
        Every term is recorded in ``self.losses`` and the gradients are
        computed by calling ``backward()`` on the total.
        """
        # Identity loss
        if self.idt_criterion is not None:
            # G_A should be identity if real_B is fed: ||G_A(B) - B||
            self.idt_A = self.nets['netG_A'](self.real_B)
            self.loss_idt_A = self.idt_criterion(self.idt_A,
                                                 self.real_B) * self.lambda_b
            # G_B should be identity if real_A is fed: ||G_B(A) - A||
            self.idt_B = self.nets['netG_B'](self.real_A)
            self.loss_idt_B = self.idt_criterion(self.idt_B,
                                                 self.real_A) * self.lambda_a

            # visual
            self.visual_items['idt_A'] = self.idt_A
            self.visual_items['idt_B'] = self.idt_B
        else:
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        # GAN loss D_A(G_A(A))
        self.loss_G_A = self.gan_criterion(self.nets['netD_A'](self.fake_B),
                                           True)
        # GAN loss D_B(G_B(B))
        self.loss_G_B = self.gan_criterion(self.nets['netD_B'](self.fake_A),
                                           True)
        # Forward cycle loss || G_B(G_A(A)) - A||
        self.loss_cycle_A = self.cycle_criterion(self.rec_A,
                                                 self.real_A) * self.lambda_a
        # Backward cycle loss || G_A(G_B(B)) - B||
        self.loss_cycle_B = self.cycle_criterion(self.rec_B,
                                                 self.real_B) * self.lambda_b

        self.losses['G_idt_A_loss'] = self.loss_idt_A
        self.losses['G_idt_B_loss'] = self.loss_idt_B
        self.losses['G_A_adv_loss'] = self.loss_G_A
        self.losses['G_B_adv_loss'] = self.loss_G_B
        self.losses['G_A_cycle_loss'] = self.loss_cycle_A
        self.losses['G_B_cycle_loss'] = self.loss_cycle_B
        # combined loss and calculate gradients
        self.loss_G = (self.loss_G_A + self.loss_G_B + self.loss_cycle_A +
                       self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B)

        self.loss_G.backward()

    def train_iter(self, optimizers=None):
        """Calculate losses, gradients, and update network weights; called in every training iteration

        Args:
            optimizers (dict): must provide the keys 'optimG' and 'optimD'.
                When None, ``self.optimizers`` is used instead.
        """
        if optimizers is None:
            optimizers = self.optimizers
        # Use one optimizer object per network for both clear_grad() and
        # step(), so the gradients that are cleared are the ones applied.
        optim_g = optimizers['optimG']
        optim_d = optimizers['optimD']

        # forward
        # compute fake images and reconstruction images.
        self.forward()

        # G_A and G_B
        # Ds require no gradients when optimizing Gs
        self.set_requires_grad([self.nets['netD_A'], self.nets['netD_B']],
                               False)
        # set G_A and G_B's gradients to zero
        optim_g.clear_grad()
        # calculate gradients for G_A and G_B
        self.backward_G()
        # update G_A and G_B's weights
        optim_g.step()

        # D_A and D_B
        self.set_requires_grad([self.nets['netD_A'], self.nets['netD_B']], True)

        # set D_A and D_B's gradients to zero
        optim_d.clear_grad()
        # calculate gradients for D_A
        self.backward_D_A()
        # calculate gradients for D_B
        self.backward_D_B()
        # update D_A and D_B's weights
        optim_d.step()

    def test_iter(self, metrics=None):
        """Run one evaluation step and update the given metrics.

        Args:
            metrics (dict): optional mapping whose values expose
                ``update(fake, real)``; each one is updated with
                ``(self.fake_B, self.real_B)``.
        """
        # forward() runs both generators, so both are switched to eval mode.
        self.nets['netG_A'].eval()
        self.nets['netG_B'].eval()
        try:
            # No gradients are needed for evaluation; running the forward
            # pass under no_grad avoids recording autograd state.
            with paddle.no_grad():
                self.forward()
                if metrics is not None:
                    for metric in metrics.values():
                        metric.update(self.fake_B, self.real_B)
        finally:
            # Always hand the generators back in train mode, even when the
            # forward pass or a metric raises.
            self.nets['netG_A'].train()
            self.nets['netG_B'].train()