# pix2pix_model.py
import paddle
from paddle.distributed import ParallelEnv
from .base_model import BaseModel

from .builder import MODELS
from .generators.builder import build_generator
from .discriminators.builder import build_discriminator
from .losses import GANLoss
from ..solver import build_optimizer
from ..utils.image_pool import ImagePool


@MODELS.register()
class Pix2PixModel(BaseModel):
    """ This class implements the pix2pix model, for learning a mapping from input images to output images given paired data.

    The model training requires a 'paired' dataset.
    By default, it uses a '--netG unet256' U-Net generator,
    a '--netD basic' discriminator (PatchGAN),
    and a vanilla GAN loss (the cross-entropy objective used in the original GAN paper).

    pix2pix paper: https://arxiv.org/pdf/1611.07004.pdf
    """
    def __init__(self, opt):
        """Initialize the pix2pix class.

        Parameters:
            opt (config dict) -- stores all the experiment flags; needs to be a subclass of Dict
        """
        BaseModel.__init__(self, opt)
        # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
        self.loss_names = ['G_GAN', 'G_L1', 'D_real', 'D_fake']
        # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
        self.visual_names = ['real_A', 'fake_B', 'real_B']
        # specify the models you want to save to the disk.
        if self.isTrain:
            self.model_names = ['G', 'D']
        else:
            # during test time, only load G
            self.model_names = ['G']

        # define networks (both generator and discriminator)
        self.netG = build_generator(opt.model.generator)

        # define a discriminator; conditional GANs need to take both input and
        # output images, so the number of channels for D is input_nc + output_nc
        if self.isTrain:
            self.netD = build_discriminator(opt.model.discriminator)

        if self.isTrain:
            # define loss functions
            self.criterionGAN = GANLoss(opt.model.gan_mode)
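            # the L1 term keeps G(A) close to the ground-truth B; its weight
            # comes from opt.lambda_L1 (applied in backward_G below)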
            self.criterionL1 = paddle.nn.L1Loss()

            # build optimizers
            self.build_lr_scheduler()
            self.optimizer_G = build_optimizer(
                opt.optimizer,
                self.lr_scheduler,
                parameter_list=self.netG.parameters())
            self.optimizer_D = build_optimizer(
                opt.optimizer,
                self.lr_scheduler,
                parameter_list=self.netD.parameters())
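            # note: optimizer_G and optimizer_D share the lr scheduler built above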

            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)
            self.optimizer_names.extend(['optimizer_G', 'optimizer_D'])

    def set_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        Parameters:
            input (dict): include the data itself and its metadata information.

        The option 'direction' can be used to swap images in domain A and domain B.
        """

        AtoB = self.opt.dataset.train.direction == 'AtoB'
        self.real_A = paddle.to_tensor(input['A' if AtoB else 'B'])
        self.real_B = paddle.to_tensor(input['B' if AtoB else 'A'])
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
        self.fake_B = self.netG(self.real_A)  # G(A)

    def forward_test(self, input):
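        """Run the generator on a raw input (converted to a tensor); used at inference time."""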
        input = paddle.to_tensor(input)
        return self.netG(input)

    def backward_D(self):
        """Calculate GAN loss for the discriminator"""
        # Fake; stop backprop to the generator by detaching fake_B
        # use conditional GANs; we need to feed both input and output to the discriminator
        fake_AB = paddle.concat((self.real_A, self.fake_B), 1)
        pred_fake = self.netD(fake_AB.detach())
        self.loss_D_fake = self.criterionGAN(pred_fake, False)
        # Real
        real_AB = paddle.concat((self.real_A, self.real_B), 1)
        pred_real = self.netD(real_AB)
        self.loss_D_real = self.criterionGAN(pred_real, True)
        # combine loss and calculate gradients
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
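        # when running data-parallel on multiple GPUs, scale the loss for the
        # number of ranks and all-reduce the gradients across devices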
        if ParallelEnv().nranks > 1:
            self.loss_D = self.netD.scale_loss(self.loss_D)
            self.loss_D.backward()
            self.netD.apply_collective_grads()
        else:
            self.loss_D.backward()

    def backward_G(self):
        """Calculate GAN and L1 loss for the generator"""
        # First, G(A) should fake the discriminator
        fake_AB = paddle.concat((self.real_A, self.fake_B), 1)
        pred_fake = self.netD(fake_AB)
        self.loss_G_GAN = self.criterionGAN(pred_fake, True)
        # Second, G(A) = B
        self.loss_G_L1 = self.criterionL1(self.fake_B,
                                          self.real_B) * self.opt.lambda_L1
        # combine loss and calculate gradients
        self.loss_G = self.loss_G_GAN + self.loss_G_L1

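        # same multi-GPU loss-scaling / gradient all-reduce pattern as backward_D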
        if ParallelEnv().nranks > 1:
            self.loss_G = self.netG.scale_loss(self.loss_G)
            self.loss_G.backward()
            self.netG.apply_collective_grads()
        else:
            self.loss_G.backward()

    def optimize_parameters(self):
        # compute fake images: G(A)
        self.forward()

        # update D
        self.set_requires_grad(self.netD, True)
        self.optimizer_D.clear_gradients()
        self.backward_D()
        self.optimizer_D.minimize(self.loss_D)

        # update G
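        # D requires no gradients when optimizing G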
        self.set_requires_grad(self.netD, False)
        self.optimizer_G.clear_gradients()
        self.backward_G()
        self.optimizer_G.minimize(self.loss_G)
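

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of this module). The
# config fields assumed here ('model.generator', 'model.discriminator',
# 'optimizer', 'dataset.train.direction', 'lambda_L1') are inferred from how
# this class reads `opt` above; `load_config` is a hypothetical helper.
#
#   opt = load_config('configs/pix2pix.yaml')
#   model = Pix2PixModel(opt)
#   for batch in train_dataloader:    # each batch: {'A': ..., 'B': ..., 'A_paths': ...}
#       model.set_input(batch)        # build real_A / real_B tensors
#       model.optimize_parameters()   # one D update followed by one G update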