#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from .base_model import BaseModel

from .builder import MODELS
from .generators.builder import build_generator
from .discriminators.builder import build_discriminator
from .criterions import build_criterion
from ..modules.init import init_weights


@MODELS.register()
class DCGANModel(BaseModel):
    """
    This class implements the DCGAN model, for learning a distribution from input images.

    DCGAN paper: https://arxiv.org/pdf/1511.06434
    """
    def __init__(self, generator, discriminator=None, gan_criterion=None):
        """Initialize the DCGAN class.

        Args:
            generator (dict): config of generator. Besides being passed to
                ``build_generator``, it is kept as ``self.gen_cfg``; its
                ``input_nz`` attribute sets the channel count of the noise
                sampled in ``forward``.
            discriminator (dict): config of discriminator. Only used when
                ``self.is_train`` is true.
            gan_criterion (dict): config of gan criterion. Only built when
                truthy; ``backward_D`` and ``backward_G`` call it, so it must
                be provided for training.
        """
        super(DCGANModel, self).__init__()
        # kept so that forward() can read the noise size (input_nz)
        self.gen_cfg = generator

        # define networks (both generator and discriminator)
        self.nets['netG'] = build_generator(generator)
        init_weights(self.nets['netG'])

        # the discriminator is only used by the training-time methods below
        if self.is_train:
            self.nets['netD'] = build_discriminator(discriminator)
            init_weights(self.nets['netD'])

        if gan_criterion:
            self.gan_criterion = build_criterion(gan_criterion)

    def setup_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        Args:
            input (dict): include the data itself and its metadata information.
                The image batch is read from key ``'A'`` and the image paths
                from key ``'A_path'``.
        """
        # get 1-channel gray image, or 3-channel color image
        self.real = paddle.to_tensor(input['A'])
        self.image_paths = input['A_path']

    def forward(self):
        """Run forward pass; called by both functions <train_iter> and <test_iter>.

        Samples one noise vector per real image (shape
        ``(batch, input_nz, 1, 1)``), maps it through the generator, and
        stores the results as ``self.z`` and ``self.fake``.
        """
        # generate random noise and fake image;
        # paddle.rand samples uniformly from [0, 1)
        self.z = paddle.rand(shape=(self.real.shape[0], self.gen_cfg.input_nz,
                                    1, 1))
        self.fake = self.nets['netG'](self.z)

        # put items to visual dict
        self.visual_items['real'] = self.real
        self.visual_items['fake'] = self.fake

    def backward_D(self):
        """Calculate GAN loss for the discriminator.

        The loss is the mean of the loss on fake images (target ``False``)
        and the loss on real images (target ``True``). It is back-propagated
        and both terms are recorded in ``self.losses``.
        """
        # Fake; stop backprop to the generator by detaching fake
        pred_fake = self.nets['netD'](self.fake.detach())
        self.loss_D_fake = self.gan_criterion(pred_fake, False)

        # Real
        pred_real = self.nets['netD'](self.real)
        self.loss_D_real = self.gan_criterion(pred_real, True)

        # combine loss and calculate gradients
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5

        self.loss_D.backward()

        self.losses['D_fake_loss'] = self.loss_D_fake
        self.losses['D_real_loss'] = self.loss_D_real

    def backward_G(self):
        """Calculate GAN loss for the generator.

        The generator is scored by the discriminator's prediction on the
        (non-detached) fake images against target ``True``. The loss is
        back-propagated and recorded as ``'G_adv_loss'``.
        """
        # G(z) should fool the discriminator
        pred_fake = self.nets['netD'](self.fake)
        self.loss_G_GAN = self.gan_criterion(pred_fake, True)

        # combine loss and calculate gradients
        self.loss_G = self.loss_G_GAN

        self.loss_G.backward()

        self.losses['G_adv_loss'] = self.loss_G_GAN

    def train_iter(self, optimizers=None):
        """Run one discriminator update followed by one generator update.

        Args:
            optimizers: unused; the optimizers are read from
                ``self.optimizers`` under keys ``'optimizer_D'`` and
                ``'optimizer_G'``.
        """
        # compute fake images: G(z)
        self.forward()

        # update D (generator gradients disabled; fake is also detached
        # inside backward_D)
        self.set_requires_grad(self.nets['netD'], True)
        self.set_requires_grad(self.nets['netG'], False)
        self.optimizers['optimizer_D'].clear_grad()
        self.backward_D()
        self.optimizers['optimizer_D'].step()

        # update G (discriminator gradients disabled so only G is stepped)
        self.set_requires_grad(self.nets['netD'], False)
        self.set_requires_grad(self.nets['netG'], True)
        self.optimizers['optimizer_G'].clear_grad()
        self.backward_G()
        self.optimizers['optimizer_G'].step()