dc_gan_model.py 4.9 KB
Newer Older
J
Jie Han 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from .base_model import BaseModel

from .builder import MODELS
from .generators.builder import build_generator
from .discriminators.builder import build_discriminator
from .losses import GANLoss

from ..solver import build_optimizer
from ..modules.init import init_weights


@MODELS.register()
class DCGANModel(BaseModel):
    """ This class implements the DCGAN model, for learning a distribution from input images.

    The model training requires dataset.
    By default, it uses a '--netG DCGenerator' generator,
    a '--netD DCDiscriminator' discriminator,
    and a vanilla GAN loss (the cross-entropy objective used in the orignal GAN paper).

    DCGAN paper: https://arxiv.org/pdf/1511.06434
    """
    def __init__(self, cfg):
        """Initialize the DCGAN class.

        Parameters:
            opt (config dict)-- stores all the experiment flags; needs to be a subclass of Dict
        """
        super(DCGANModel, self).__init__(cfg)
        # define networks (both generator and discriminator)
        self.nets['netG'] = build_generator(cfg.model.generator)
        init_weights(self.nets['netG'])
        self.cfg = cfg
        if self.is_train:
            self.nets['netD'] = build_discriminator(cfg.model.discriminator)
            init_weights(self.nets['netD'])

        if self.is_train:
            self.losses = {}
            # define loss functions
            self.criterionGAN = GANLoss(cfg.model.gan_mode)

            # build optimizers
            self.build_lr_scheduler()
            self.optimizers['optimizer_G'] = build_optimizer(
                cfg.optimizer,
                self.lr_scheduler,
                parameter_list=self.nets['netG'].parameters())
            self.optimizers['optimizer_D'] = build_optimizer(
                cfg.optimizer,
                self.lr_scheduler,
                parameter_list=self.nets['netD'].parameters())

    def set_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        Parameters:
            input (dict): include the data itself and its metadata information.
        """
        # get 1-channel gray image, or 3-channel color image
        self.real = paddle.to_tensor(input['A'][:,0:self.cfg.model.generator.input_nc,:,:])
        self.image_paths = input['A_paths']

    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""

        # generate random noise and fake image
        self.z = paddle.rand(shape=(self.real.shape[0],self.cfg.model.generator.input_nz,1,1))
        self.fake = self.nets['netG'](self.z) 

        # put items to visual dict
        self.visual_items['real'] = self.real
        self.visual_items['fake'] = self.fake

    def backward_D(self):
        """Calculate GAN loss for the discriminator"""
        # Fake; stop backprop to the generator by detaching fake
        pred_fake = self.nets['netD'](self.fake.detach())
        self.loss_D_fake = self.criterionGAN(pred_fake, False)

        pred_real = self.nets['netD'](self.real)
        self.loss_D_real = self.criterionGAN(pred_real, True)

        # combine loss and calculate gradients
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5

        self.loss_D.backward()

        self.losses['D_fake_loss'] = self.loss_D_fake
        self.losses['D_real_loss'] = self.loss_D_real

    def backward_G(self):
        """Calculate GAN loss for the generator"""
        # G(A) should fake the discriminator
        pred_fake = self.nets['netD'](self.fake)
        self.loss_G_GAN = self.criterionGAN(pred_fake, True)

        # combine loss and calculate gradients
        self.loss_G = self.loss_G_GAN

        self.loss_G.backward()

        self.losses['G_adv_loss'] = self.loss_G_GAN

    def optimize_parameters(self):
        # compute fake images: G(A)
        self.forward()

        #update D
        self.set_requires_grad(self.nets['netD'], True)
        self.set_requires_grad(self.nets['netG'], False)
        self.optimizers['optimizer_D'].clear_grad()
        self.backward_D()
        self.optimizers['optimizer_D'].step()

        # update G
        self.set_requires_grad(self.nets['netD'], False)
        self.set_requires_grad(self.nets['netG'], True)
        self.optimizers['optimizer_G'].clear_grad()
        self.backward_G()
        self.optimizers['optimizer_G'].step()