# pix2pix_model.py (committed by LielinJiang)
import paddle
from .base_model import BaseModel

from .builder import MODELS
from .generators.builder import build_generator
from .discriminators.builder import build_discriminator
from .losses import GANLoss
from ..solver import build_optimizer
from ..utils.image_pool import ImagePool


@MODELS.register()
class Pix2PixModel(BaseModel):
    """ This class implements the pix2pix model, for learning a mapping from input images to output images given paired data.

17
    The model training requires 'paired' dataset.
L
LielinJiang 已提交
18
    By default, it uses a '--netG unet256' U-Net generator,
19 20
    a '--netD basic' discriminator (from PatchGAN),
    and a vanilla GAN loss (the cross-entropy objective used in the orignal GAN paper).
L
LielinJiang 已提交
21 22 23 24 25 26 27 28

    pix2pix paper: https://arxiv.org/pdf/1611.07004.pdf
    """

    def __init__(self, opt):
        """Initialize the pix2pix class.

        Parameters:
29
            opt (config dict)-- stores all the experiment flags; needs to be a subclass of Dict
L
LielinJiang 已提交
30 31 32 33 34 35
        """
        BaseModel.__init__(self, opt)
        # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
        self.loss_names = ['G_GAN', 'G_L1', 'D_real', 'D_fake']
        # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
        self.visual_names = ['real_A', 'fake_B', 'real_B']
36
        # specify the models you want to save to the disk.
L
LielinJiang 已提交
37 38
        if self.isTrain:
            self.model_names = ['G', 'D']
39 40
        else:
            # during test time, only load G
L
LielinJiang 已提交
41
            self.model_names = ['G']
42

L
LielinJiang 已提交
43 44 45
        # define networks (both generator and discriminator)
        self.netG = build_generator(opt.model.generator)

46 47 48

        # define a discriminator; conditional GANs need to take both input and output images; Therefore, #channels for D is input_nc + output_nc
        if self.isTrain:
L
LielinJiang 已提交
49
            self.netD = build_discriminator(opt.model.discriminator)
50

L
LielinJiang 已提交
51 52 53

        if self.isTrain:
            # define loss functions
54
            self.criterionGAN = GANLoss(opt.model.gan_mode)
L
LielinJiang 已提交
55
            self.criterionL1 = paddle.nn.L1Loss()
56 57 58 59

            # build optimizers
            self.optimizer_G = build_optimizer(opt.optimizer, parameter_list=self.netG.parameters())
            self.optimizer_D = build_optimizer(opt.optimizer, parameter_list=self.netD.parameters()) 
L
LielinJiang 已提交
60 61 62 63 64 65 66 67 68 69 70 71 72

            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)
            self.optimizer_names.extend(['optimizer_G', 'optimizer_D'])

    def set_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        Parameters:
            input (dict): include the data itself and its metadata information.

        The option 'direction' can be used to swap images in domain A and domain B.
        """
73

L
LielinJiang 已提交
74
        AtoB = self.opt.dataset.train.direction == 'AtoB'
L
LielinJiang 已提交
75 76 77
        self.real_A = paddle.imperative.to_variable(input['A' if AtoB else 'B'])
        self.real_B = paddle.imperative.to_variable(input['B' if AtoB else 'A'])
        self.image_paths = input['A_paths' if AtoB else 'B_paths']
78
 
L
LielinJiang 已提交
79 80 81 82 83 84 85 86 87 88 89 90

    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
        self.fake_B = self.netG(self.real_A)  # G(A)

    def forward_test(self, input):
        input = paddle.imperative.to_variable(input)
        return self.netG(input)
            
    def backward_D(self):
        """Calculate GAN loss for the discriminator"""
        # Fake; stop backprop to the generator by detaching fake_B
91 92
        # use conditional GANs; we need to feed both input and output to the discriminator
        fake_AB = paddle.concat((self.real_A, self.fake_B), 1)
L
LielinJiang 已提交
93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116
        pred_fake = self.netD(fake_AB.detach())
        self.loss_D_fake = self.criterionGAN(pred_fake, False)
        # Real
        real_AB = paddle.concat((self.real_A, self.real_B), 1)
        pred_real = self.netD(real_AB)
        self.loss_D_real = self.criterionGAN(pred_real, True)
        # combine loss and calculate gradients
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
        self.loss_D.backward()

    def backward_G(self):
        """Calculate GAN and L1 loss for the generator"""
        # First, G(A) should fake the discriminator
        fake_AB = paddle.concat((self.real_A, self.fake_B), 1)
        pred_fake = self.netD(fake_AB)
        self.loss_G_GAN = self.criterionGAN(pred_fake, True)
        # Second, G(A) = B
        self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_L1
        # combine loss and calculate gradients
        self.loss_G = self.loss_G_GAN + self.loss_G_L1
        # self.loss_G = self.loss_G_L1
        self.loss_G.backward()

    def optimize_parameters(self):
117 118 119
        # compute fake images: G(A)
        self.forward()

L
LielinJiang 已提交
120
        # update D
121 122 123 124 125
        self.set_requires_grad(self.netD, True)
        self.optimizer_D.clear_gradients() 
        self.backward_D()
        self.optimizer_D.minimize(self.loss_D) 
       
L
LielinJiang 已提交
126
        # update G
127 128 129 130
        self.set_requires_grad(self.netD, False) 
        self.optimizer_G.clear_gradients()
        self.backward_G()
        self.optimizer_G.minimize(self.loss_G)