#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from .base_model import BaseModel

from .builder import MODELS
from .generators.builder import build_generator
from .discriminators.builder import build_discriminator
from .losses import GANLoss

from ..solver import build_optimizer
from ..modules.init import init_weights
from ..utils.image_pool import ImagePool


@MODELS.register()
class Pix2PixModel(BaseModel):
    """ This class implements the pix2pix model, for learning a mapping from input images to output images given paired data.

    The model training requires a 'paired' dataset.
    By default, it uses a '--netG unet256' U-Net generator,
    a '--netD basic' discriminator (PatchGAN),
    and a vanilla GAN loss (the cross-entropy objective used in the original GAN paper).

    pix2pix paper: https://arxiv.org/pdf/1611.07004.pdf
    """
    def __init__(self, cfg):
        """Initialize the pix2pix class.

        Parameters:
            cfg (dict): config dict that stores all the experiment flags.
        """
        super(Pix2PixModel, self).__init__(cfg)
        # define networks (both generator and discriminator)
        self.nets['netG'] = build_generator(cfg.model.generator)
        init_weights(self.nets['netG'])

        # define a discriminator; conditional GANs need to take both input and
        # output images, so the number of channels for D is input_nc + output_nc
        if self.is_train:
            self.nets['netD'] = build_discriminator(cfg.model.discriminator)
            init_weights(self.nets['netD'])

        if self.is_train:
            self.losses = {}
            # define loss functions
            self.criterionGAN = GANLoss(cfg.model.gan_mode)
            self.criterionL1 = paddle.nn.L1Loss()

            # build optimizers
            self.build_lr_scheduler()
            self.optimizers['optimizer_G'] = build_optimizer(
                cfg.optimizer,
                self.lr_scheduler,
                parameter_list=self.nets['netG'].parameters())
            self.optimizers['optimizer_D'] = build_optimizer(
                cfg.optimizer,
                self.lr_scheduler,
                parameter_list=self.nets['netD'].parameters())

    def set_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        Parameters:
            input (dict): includes the data itself and its metadata information.

        The option 'direction' can be used to swap images in domain A and domain B.
        """

        AtoB = self.cfg.dataset.train.direction == 'AtoB'

        self.real_A = paddle.to_tensor(input['A' if AtoB else 'B'])
        self.real_B = paddle.to_tensor(input['B' if AtoB else 'A'])

        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
        self.fake_B = self.nets['netG'](self.real_A)  # G(A)

        # put items to visual dict
        self.visual_items['fake_B'] = self.fake_B
        self.visual_items['real_A'] = self.real_A
        self.visual_items['real_B'] = self.real_B

    def backward_D(self):
        """Calculate GAN loss for the discriminator"""
        # Fake; stop backprop to the generator by detaching fake_B
        # we use conditional GANs, so we need to feed both input and output to
        # the discriminator
        fake_AB = paddle.concat((self.real_A, self.fake_B), 1)
        pred_fake = self.nets['netD'](fake_AB.detach())
        self.loss_D_fake = self.criterionGAN(pred_fake, False)
        # Real
        real_AB = paddle.concat((self.real_A, self.real_B), 1)
        pred_real = self.nets['netD'](real_AB)
        self.loss_D_real = self.criterionGAN(pred_real, True)
        # combine loss and calculate gradients
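        # (the 0.5 factor halves the objective for D, slowing how fast D learns
        # relative to G, as in the original pix2pix implementation)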
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5

        self.loss_D.backward()

        self.losses['D_fake_loss'] = self.loss_D_fake
        self.losses['D_real_loss'] = self.loss_D_real

    def backward_G(self):
        """Calculate GAN and L1 loss for the generator"""
        # First, G(A) should fake the discriminator
        fake_AB = paddle.concat((self.real_A, self.fake_B), 1)
        pred_fake = self.nets['netD'](fake_AB)
        self.loss_G_GAN = self.criterionGAN(pred_fake, True)
        # Second, G(A) = B
        self.loss_G_L1 = self.criterionL1(self.fake_B,
                                          self.real_B) * self.cfg.lambda_L1

        # combine loss and calculate gradients
        self.loss_G = self.loss_G_GAN + self.loss_G_L1

        self.loss_G.backward()

        self.losses['G_adv_loss'] = self.loss_G_GAN
        self.losses['G_L1_loss'] = self.loss_G_L1

    def optimize_parameters(self):
        # compute fake images: G(A)
        self.forward()

        # update D
        self.set_requires_grad(self.nets['netD'], True)
        self.optimizers['optimizer_D'].clear_grad()
        self.backward_D()
        self.optimizers['optimizer_D'].step()

        # update G
        self.set_requires_grad(self.nets['netD'], False)
        self.optimizers['optimizer_G'].clear_grad()
        self.backward_G()
        self.optimizers['optimizer_G'].step()
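

# A minimal sketch of how a training loop is expected to drive this model. The
# data dict layout follows set_input above; the loop itself is illustrative
# only and is not PaddleGAN's actual Trainer:
#
#     model = Pix2PixModel(cfg)
#     for data in paired_dataloader:    # dict with 'A', 'B', 'A_paths', 'B_paths'
#         model.set_input(data)         # unpack inputs, honoring cfg.dataset.train.direction
#         model.optimize_parameters()   # forward pass, update D, then update G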