#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle

from .generators.builder import build_generator
from .discriminators.builder import build_discriminator
from .sr_model import BaseSRModel
from .builder import MODELS

from .criterions import build_criterion


@MODELS.register()
class ESRGAN(BaseSRModel):
    """
    This class implements the ESRGAN model.

    ESRGAN paper: https://arxiv.org/pdf/1809.00219.pdf
    """
    def __init__(self,
                 generator,
                 discriminator=None,
                 pixel_criterion=None,
                 perceptual_criterion=None,
                 gan_criterion=None):
        """Build the networks and loss criteria used by ESRGAN.

        Args:
            generator (dict): config of generator.
            discriminator (dict): config of discriminator. The discriminator
                network is only built when this is truthy.
            pixel_criterion (dict): config of pixel criterion.
            perceptual_criterion (dict): config of perceptual criterion.
            gan_criterion (dict): config of gan criterion.

        Note:
            A criterion attribute is created only when its config is truthy;
            otherwise the attribute is left undefined on the instance.
        """
        super(ESRGAN, self).__init__(generator)

        # NOTE(review): the parent initializer also receives the generator
        # config - confirm whether building the generator again is needed.
        self.nets['generator'] = build_generator(generator)
        if discriminator:
            self.nets['discriminator'] = build_discriminator(discriminator)

        # (attribute name, config) pairs, processed in declaration order.
        criterion_configs = (
            ('pixel_criterion', pixel_criterion),
            ('perceptual_criterion', perceptual_criterion),
            ('gan_criterion', gan_criterion),
        )
        for attr_name, criterion_cfg in criterion_configs:
            if criterion_cfg:
                setattr(self, attr_name, build_criterion(criterion_cfg))

    def train_iter(self, optimizers=None):
        """Run one training step of ESRGAN.

        The generator is updated first with the sum of every enabled loss
        (pixel, perceptual/style and relativistic GAN). When a GAN criterion
        is configured, the discriminator is then updated on the same batch.
        Each loss value is stored in ``self.losses`` and the generator output
        in ``self.visual_items['output']``.

        Args:
            optimizers (dict): optimizers keyed by name. ``'optimG'`` is
                always used; ``'optimD'`` is used only when a GAN criterion
                is configured.
        """
        # The criterion attributes are only assigned in __init__ when their
        # config is given, so read them defensively instead of assuming they
        # exist (a plain attribute access could raise AttributeError).
        pixel_criterion = getattr(self, 'pixel_criterion', None)
        perceptual_criterion = getattr(self, 'perceptual_criterion', None)

        optimizers['optimG'].clear_grad()
        l_total = 0
        self.output = self.nets['generator'](self.lq)
        self.visual_items['output'] = self.output
        # pixel loss
        if pixel_criterion:
            l_pix = pixel_criterion(self.output, self.gt)
            l_total += l_pix
            self.losses['loss_pix'] = l_pix
        # perceptual / style loss (either part may be None)
        if perceptual_criterion:
            l_g_percep, l_g_style = perceptual_criterion(self.output, self.gt)
            if l_g_percep is not None:
                l_total += l_g_percep
                self.losses['loss_percep'] = l_g_percep
            if l_g_style is not None:
                l_total += l_g_style
                self.losses['loss_style'] = l_g_style

        # gan loss (relativistic gan)
        if hasattr(self, 'gan_criterion'):
            # Freeze the discriminator while the generator is optimized.
            self.set_requires_grad(self.nets['discriminator'], False)
            real_d_pred = self.nets['discriminator'](self.gt).detach()
            fake_g_pred = self.nets['discriminator'](self.output)
            l_g_real = self.gan_criterion(real_d_pred -
                                          paddle.mean(fake_g_pred),
                                          False,
                                          is_disc=False)
            l_g_fake = self.gan_criterion(fake_g_pred -
                                          paddle.mean(real_d_pred),
                                          True,
                                          is_disc=False)
            l_g_gan = (l_g_real + l_g_fake) / 2

            l_total += l_g_gan
            self.losses['l_g_gan'] = l_g_gan
            l_total.backward()
            optimizers['optimG'].step()

            # Discriminator update.
            self.set_requires_grad(self.nets['discriminator'], True)
            optimizers['optimD'].clear_grad()
            # real
            fake_d_pred = self.nets['discriminator'](self.output).detach()
            real_d_pred = self.nets['discriminator'](self.gt)
            l_d_real = self.gan_criterion(
                real_d_pred - paddle.mean(fake_d_pred), True,
                is_disc=True) * 0.5

            # fake
            fake_d_pred = self.nets['discriminator'](self.output.detach())
            l_d_fake = self.gan_criterion(
                fake_d_pred - paddle.mean(real_d_pred.detach()),
                False,
                is_disc=True) * 0.5

            (l_d_real + l_d_fake).backward()
            optimizers['optimD'].step()

            self.losses['l_d_real'] = l_d_real
            self.losses['l_d_fake'] = l_d_fake
            self.losses['out_d_real'] = paddle.mean(real_d_pred.detach())
            self.losses['out_d_fake'] = paddle.mean(fake_d_pred.detach())
        else:
            l_total.backward()
            optimizers['optimG'].step()

    # amp training
    def train_iter_amp(self, optimizers=None, scalers=None, amp_level='O1'):
        """Run one mixed-precision (AMP) training step of ESRGAN.

        Mirrors ``train_iter`` but runs the forward passes and loss
        computation inside ``paddle.amp.auto_cast`` and applies the
        parameter updates through gradient scalers.

        Args:
            optimizers (dict): optimizers keyed by name. ``'optimG'`` is
                always used; ``'optimD'`` is used only when a GAN criterion
                is configured.
            scalers (list): gradient scalers; ``scalers[0]`` handles the
                generator loss and ``scalers[1]`` the discriminator loss.
            amp_level (str): value passed as ``level`` to
                ``paddle.amp.auto_cast``.
        """
        # The criterion attributes are only assigned in __init__ when their
        # config is given, so read them defensively instead of assuming they
        # exist (a plain attribute access could raise AttributeError).
        pixel_criterion = getattr(self, 'pixel_criterion', None)
        perceptual_criterion = getattr(self, 'perceptual_criterion', None)

        optimizers['optimG'].clear_grad()
        l_total = 0

        # put loss computation in amp context
        with paddle.amp.auto_cast(enable=True, level=amp_level):
            self.output = self.nets['generator'](self.lq)
            self.visual_items['output'] = self.output
            # pixel loss
            if pixel_criterion:
                l_pix = pixel_criterion(self.output, self.gt)
                l_total += l_pix
                self.losses['loss_pix'] = l_pix
            # perceptual / style loss (either part may be None)
            if perceptual_criterion:
                l_g_percep, l_g_style = perceptual_criterion(
                    self.output, self.gt)
                if l_g_percep is not None:
                    l_total += l_g_percep
                    self.losses['loss_percep'] = l_g_percep
                if l_g_style is not None:
                    l_total += l_g_style
                    self.losses['loss_style'] = l_g_style

        # gan loss (relativistic gan)
        if hasattr(self, 'gan_criterion'):
            # Freeze the discriminator while the generator is optimized.
            self.set_requires_grad(self.nets['discriminator'], False)

            # put fwd and loss computation in amp context
            with paddle.amp.auto_cast(enable=True, level=amp_level):
                real_d_pred = self.nets['discriminator'](self.gt).detach()
                fake_g_pred = self.nets['discriminator'](self.output)
                l_g_real = self.gan_criterion(real_d_pred -
                                              paddle.mean(fake_g_pred),
                                              False,
                                              is_disc=False)
                l_g_fake = self.gan_criterion(fake_g_pred -
                                              paddle.mean(real_d_pred),
                                              True,
                                              is_disc=False)
                l_g_gan = (l_g_real + l_g_fake) / 2

                l_total += l_g_gan
                self.losses['l_g_gan'] = l_g_gan

            scaled_l_total = scalers[0].scale(l_total)
            scaled_l_total.backward()
            # The scaler's minimize() unscales the gradients and performs the
            # optimizer update itself, so no separate optimizer.step() is
            # issued here: that would apply the still-scaled gradients and
            # update the generator twice.
            scalers[0].minimize(optimizers['optimG'], scaled_l_total)

            # Discriminator update.
            self.set_requires_grad(self.nets['discriminator'], True)
            optimizers['optimD'].clear_grad()

            with paddle.amp.auto_cast(enable=True, level=amp_level):
                # real
                fake_d_pred = self.nets['discriminator'](self.output).detach()
                real_d_pred = self.nets['discriminator'](self.gt)
                l_d_real = self.gan_criterion(
                    real_d_pred - paddle.mean(fake_d_pred), True,
                    is_disc=True) * 0.5

                # fake
                fake_d_pred = self.nets['discriminator'](self.output.detach())
                l_d_fake = self.gan_criterion(
                    fake_d_pred - paddle.mean(real_d_pred.detach()),
                    False,
                    is_disc=True) * 0.5

            l_d_total = l_d_real + l_d_fake
            scaled_l_d_total = scalers[1].scale(l_d_total)
            scaled_l_d_total.backward()
            # Unscale with the same scaler that scaled the loss; using a
            # different scaler would divide by the wrong loss-scale factor.
            scalers[1].minimize(optimizers['optimD'], scaled_l_d_total)

            self.losses['l_d_real'] = l_d_real
            self.losses['l_d_fake'] = l_d_fake
            self.losses['out_d_real'] = paddle.mean(real_d_pred.detach())
            self.losses['out_d_fake'] = paddle.mean(fake_d_pred.detach())
        else:
            scaled_l_total = scalers[0].scale(l_total)
            scaled_l_total.backward()
            # minimize() performs the (unscaled) optimizer update; see above.
            scalers[0].minimize(optimizers['optimG'], scaled_l_total)