pix2pix_model.py 5.2 KB
Newer Older
Q
qingqing01 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

L
LielinJiang 已提交
15 16 17 18 19 20
import paddle
from .base_model import BaseModel

from .builder import MODELS
from .generators.builder import build_generator
from .discriminators.builder import build_discriminator
21
from .criterions import build_criterion
22

L
LielinJiang 已提交
23
from ..solver import build_optimizer
L
LielinJiang 已提交
24
from ..modules.init import init_weights
L
LielinJiang 已提交
25 26 27 28 29 30 31 32 33
from ..utils.image_pool import ImagePool


@MODELS.register()
class Pix2PixModel(BaseModel):
    """ This class implements the pix2pix model, for learning a mapping from input images to output images given paired data.

    pix2pix paper: https://arxiv.org/pdf/1611.07004.pdf
    """
34 35 36 37 38 39
    def __init__(self,
                 generator,
                 discriminator=None,
                 pixel_criterion=None,
                 gan_criterion=None,
                 direction='a2b'):
L
LielinJiang 已提交
40 41
        """Initialize the pix2pix class.

42 43 44 45 46
        Args:
            generator (dict): config of generator.
            discriminator (dict): config of discriminator.
            pixel_criterion (dict): config of pixel criterion.
            gan_criterion (dict): config of gan criterion.
L
LielinJiang 已提交
47
        """
48 49 50
        super(Pix2PixModel, self).__init__()

        self.direction = direction
L
LielinJiang 已提交
51
        # define networks (both generator and discriminator)
52
        self.nets['netG'] = build_generator(generator)
L
LielinJiang 已提交
53
        init_weights(self.nets['netG'])
L
LielinJiang 已提交
54

55
        # define a discriminator; conditional GANs need to take both input and output images; Therefore, #channels for D is input_nc + output_nc
56 57
        if discriminator:
            self.nets['netD'] = build_discriminator(discriminator)
L
LielinJiang 已提交
58
            init_weights(self.nets['netD'])
59

60 61 62 63 64 65 66
        if pixel_criterion:
            self.pixel_criterion = build_criterion(pixel_criterion)

        if gan_criterion:
            self.gan_criterion = build_criterion(gan_criterion)

    def setup_input(self, input):
L
LielinJiang 已提交
67 68
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

69
        Args:
L
LielinJiang 已提交
70 71 72 73
            input (dict): include the data itself and its metadata information.

        The option 'direction' can be used to swap images in domain A and domain B.
        """
74

75
        AtoB = self.direction == 'AtoB'
L
LielinJiang 已提交
76

L
LielinJiang 已提交
77 78 79 80
        self.real_A = paddle.fluid.dygraph.to_variable(
            input['A' if AtoB else 'B'])
        self.real_B = paddle.fluid.dygraph.to_variable(
            input['B' if AtoB else 'A'])
L
fix nan  
LielinJiang 已提交
81

82
        self.image_paths = input['A_path' if AtoB else 'B_path']
L
LielinJiang 已提交
83 84 85

    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
L
LielinJiang 已提交
86
        self.fake_B = self.nets['netG'](self.real_A)  # G(A)
L
LielinJiang 已提交
87

L
LielinJiang 已提交
88 89 90 91
        # put items to visual dict
        self.visual_items['fake_B'] = self.fake_B
        self.visual_items['real_A'] = self.real_A
        self.visual_items['real_B'] = self.real_B
L
LielinJiang 已提交
92

L
LielinJiang 已提交
93 94 95
    def backward_D(self):
        """Calculate GAN loss for the discriminator"""
        # Fake; stop backprop to the generator by detaching fake_B
96 97
        # use conditional GANs; we need to feed both input and output to the discriminator
        fake_AB = paddle.concat((self.real_A, self.fake_B), 1)
L
LielinJiang 已提交
98
        pred_fake = self.nets['netD'](fake_AB.detach())
99
        self.loss_D_fake = self.gan_criterion(pred_fake, False)
L
LielinJiang 已提交
100 101
        # Real
        real_AB = paddle.concat((self.real_A, self.real_B), 1)
L
LielinJiang 已提交
102
        pred_real = self.nets['netD'](real_AB)
103
        self.loss_D_real = self.gan_criterion(pred_real, True)
L
LielinJiang 已提交
104 105
        # combine loss and calculate gradients
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
L
LielinJiang 已提交
106 107

        self.loss_D.backward()
L
LielinJiang 已提交
108

L
lijianshe02 已提交
109 110
        self.losses['D_fake_loss'] = self.loss_D_fake
        self.losses['D_real_loss'] = self.loss_D_real
111

L
LielinJiang 已提交
112 113 114 115
    def backward_G(self):
        """Calculate GAN and L1 loss for the generator"""
        # First, G(A) should fake the discriminator
        fake_AB = paddle.concat((self.real_A, self.fake_B), 1)
L
LielinJiang 已提交
116
        pred_fake = self.nets['netD'](fake_AB)
117
        self.loss_G_GAN = self.gan_criterion(pred_fake, True)
L
LielinJiang 已提交
118
        # Second, G(A) = B
119
        self.loss_G_L1 = self.pixel_criterion(self.fake_B, self.real_B)
L
fix nan  
LielinJiang 已提交
120

L
LielinJiang 已提交
121 122
        # combine loss and calculate gradients
        self.loss_G = self.loss_G_GAN + self.loss_G_L1
123

L
LielinJiang 已提交
124
        self.loss_G.backward()
L
LielinJiang 已提交
125

L
lijianshe02 已提交
126 127
        self.losses['G_adv_loss'] = self.loss_G_GAN
        self.losses['G_L1_loss'] = self.loss_G_L1
128

129
    def train_iter(self, optimizers=None):
130 131 132
        # compute fake images: G(A)
        self.forward()

L
LielinJiang 已提交
133
        # update D
L
LielinJiang 已提交
134
        self.set_requires_grad(self.nets['netD'], True)
135
        optimizers['optimD'].clear_grad()
136
        self.backward_D()
137
        optimizers['optimD'].step()
L
LielinJiang 已提交
138

L
LielinJiang 已提交
139
        # update G
L
LielinJiang 已提交
140
        self.set_requires_grad(self.nets['netD'], False)
141
        optimizers['optimG'].clear_grad()
142
        self.backward_G()
143
        optimizers['optimG'].step()