#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn
from .base_model import BaseModel

from .builder import MODELS
from .generators.builder import build_generator
from .discriminators.builder import build_discriminator
from .criterions import build_criterion

from ..solver import build_optimizer
from ..modules.nn import RhoClipper
from ..modules.init import init_weights
from ..utils.image_pool import ImagePool


@MODELS.register()
class UGATITModel(BaseModel):
    """
    This class implements the UGATIT model, for learning image-to-image translation without paired data.

    UGATIT paper: https://arxiv.org/pdf/1907.10830.pdf
    """
    def __init__(self,
                 generator,
                 discriminator_g=None,
                 discriminator_l=None,
                 l1_criterion=None,
                 mse_criterion=None,
                 bce_criterion=None,
                 direction='a2b',
                 adv_weight=1.0,
                 cycle_weight=10.0,
                 identity_weight=10.0,
                 cam_weight=1000.0):
        """Initialize the UGATIT model class.

        Args:
            generator (dict): config of the generator; the same config is used
                to build both translation directions (A2B and B2A).
            discriminator_g (dict, optional): config of the global discriminator.
                Discriminators are only built when both ``discriminator_g`` and
                ``discriminator_l`` are given (i.e. for training).
            discriminator_l (dict, optional): config of the local discriminator.
            l1_criterion (dict, optional): config of the L1 loss, used for the
                cycle-consistency and identity terms.
            mse_criterion (dict, optional): config of the MSE loss, used for the
                adversarial terms.
            bce_criterion (dict, optional): config of the BCE loss, used for the
                CAM attention terms.
            direction (str): translation direction, ``'a2b'`` or ``'b2a'``.
            adv_weight (float): weight of the adversarial loss.
            cycle_weight (float): weight of the cycle-consistency loss.
            identity_weight (float): weight of the identity loss.
            cam_weight (float): weight of the CAM attention loss.
        """
        super(UGATITModel, self).__init__()
        self.adv_weight = adv_weight
        self.cycle_weight = cycle_weight
        self.identity_weight = identity_weight
        self.cam_weight = cam_weight
        self.direction = direction

        # define networks (both Generators and discriminators)
        # The naming is different from those used in the paper.
        self.nets['genA2B'] = build_generator(generator)
        self.nets['genB2A'] = build_generator(generator)

        init_weights(self.nets['genA2B'])
        init_weights(self.nets['genB2A'])

        if discriminator_g and discriminator_l:
            # define discriminators: one global (G) and one local (L)
            # discriminator per domain (A and B)
            self.nets['disGA'] = build_discriminator(discriminator_g)
            self.nets['disGB'] = build_discriminator(discriminator_g)
            self.nets['disLA'] = build_discriminator(discriminator_l)
            self.nets['disLB'] = build_discriminator(discriminator_l)

            init_weights(self.nets['disGA'])
            init_weights(self.nets['disGB'])
            init_weights(self.nets['disLA'])
            init_weights(self.nets['disLB'])

        # define loss functions
        if l1_criterion:
            self.L1_loss = build_criterion(l1_criterion)
        if bce_criterion:
            self.BCE_loss = build_criterion(bce_criterion)
        if mse_criterion:
            self.MSE_loss = build_criterion(mse_criterion)

        # clips the rho parameters of AdaILN/ILN layers into [0, 1]
        # after each generator update
        self.Rho_clipper = RhoClipper(0, 1)

    def setup_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        Args:
            input (dict): include the data itself and its metadata information.

        The option 'direction' can be used to swap domain A and domain B.
        """
        AtoB = self.direction == 'a2b'

        # either domain may be absent (e.g. one-sided inference),
        # so only set the attributes that are present in the batch
        if AtoB:
            if 'A' in input:
                self.real_A = paddle.to_tensor(input['A'])
            if 'B' in input:
                self.real_B = paddle.to_tensor(input['B'])
        else:
            if 'B' in input:
                self.real_A = paddle.to_tensor(input['B'])
            if 'A' in input:
                self.real_B = paddle.to_tensor(input['A'])

        if 'A_paths' in input:
            self.image_paths = input['A_paths']
        elif 'B_paths' in input:
            self.image_paths = input['B_paths']

    def forward(self):
        """Run forward pass; called by both functions <train_iter> and <test_iter>."""
        # translate whichever domains were provided by setup_input;
        # generators return (image, cam_logit, heatmap) — only the image
        # is needed here
        if hasattr(self, 'real_A'):
            self.fake_A2B, _, _ = self.nets['genA2B'](self.real_A)

            # visual
            self.visual_items['real_A'] = self.real_A
            self.visual_items['fake_A2B'] = self.fake_A2B

        if hasattr(self, 'real_B'):
            self.fake_B2A, _, _ = self.nets['genB2A'](self.real_B)

            # visual
            self.visual_items['real_B'] = self.real_B
            self.visual_items['fake_B2A'] = self.fake_B2A

    def test_iter(self, metrics=None):
        """Forward function used in test time.

        This function wraps <forward> function in no_grad() so we don't save intermediate steps for backprop
        It also calls <compute_visuals> to produce additional visualization results

        Args:
            metrics (optional): unused here; kept for interface compatibility.
        """
        self.nets['genA2B'].eval()
        self.nets['genB2A'].eval()
        with paddle.no_grad():
            self.forward()
            self.compute_visuals()

        # restore training mode for subsequent train iterations
        self.nets['genA2B'].train()
        self.nets['genB2A'].train()

    def train_iter(self, optimizers=None):
        """Calculate losses, gradients, and update network weights; called in every training iteration

        Args:
            optimizers (dict): must contain 'optimD' (discriminators) and
                'optimG' (generators).
        """
        def _criterion(loss_func, logit, is_real):
            # LSGAN/BCE-style target: all-ones for real, all-zeros for fake
            if is_real:
                target = paddle.ones_like(logit)
            else:
                target = paddle.zeros_like(logit)
            return loss_func(logit, target)

        # forward
        # compute fake images and reconstruction images.
        self.forward()

        # update D
        optimizers['optimD'].clear_grad()

        # discriminator logits on real images (per-patch logit + CAM logit)
        real_GA_logit, real_GA_cam_logit, _ = self.nets['disGA'](self.real_A)
        real_LA_logit, real_LA_cam_logit, _ = self.nets['disLA'](self.real_A)
        real_GB_logit, real_GB_cam_logit, _ = self.nets['disGB'](self.real_B)
        real_LB_logit, real_LB_cam_logit, _ = self.nets['disLB'](self.real_B)

        # discriminator logits on generated images (fakes come from forward())
        fake_GA_logit, fake_GA_cam_logit, _ = self.nets['disGA'](self.fake_B2A)
        fake_LA_logit, fake_LA_cam_logit, _ = self.nets['disLA'](self.fake_B2A)
        fake_GB_logit, fake_GB_cam_logit, _ = self.nets['disGB'](self.fake_A2B)
        fake_LB_logit, fake_LB_cam_logit, _ = self.nets['disLB'](self.fake_A2B)

        # LSGAN adversarial losses: real -> 1, fake -> 0
        D_ad_loss_GA = _criterion(self.MSE_loss,
                                  real_GA_logit, True) + _criterion(
                                      self.MSE_loss, fake_GA_logit, False)

        D_ad_cam_loss_GA = _criterion(
            self.MSE_loss, real_GA_cam_logit, True) + _criterion(
                self.MSE_loss, fake_GA_cam_logit, False)

        D_ad_loss_LA = _criterion(self.MSE_loss,
                                  real_LA_logit, True) + _criterion(
                                      self.MSE_loss, fake_LA_logit, False)

        D_ad_cam_loss_LA = _criterion(
            self.MSE_loss, real_LA_cam_logit, True) + _criterion(
                self.MSE_loss, fake_LA_cam_logit, False)

        D_ad_loss_GB = _criterion(self.MSE_loss,
                                  real_GB_logit, True) + _criterion(
                                      self.MSE_loss, fake_GB_logit, False)

        D_ad_cam_loss_GB = _criterion(
            self.MSE_loss, real_GB_cam_logit, True) + _criterion(
                self.MSE_loss, fake_GB_cam_logit, False)

        D_ad_loss_LB = _criterion(self.MSE_loss,
                                  real_LB_logit, True) + _criterion(
                                      self.MSE_loss, fake_LB_logit, False)

        D_ad_cam_loss_LB = _criterion(
            self.MSE_loss, real_LB_cam_logit, True) + _criterion(
                self.MSE_loss, fake_LB_cam_logit, False)

        D_loss_A = self.adv_weight * (D_ad_loss_GA + D_ad_cam_loss_GA +
                                      D_ad_loss_LA + D_ad_cam_loss_LA)
        D_loss_B = self.adv_weight * (D_ad_loss_GB + D_ad_cam_loss_GB +
                                      D_ad_loss_LB + D_ad_cam_loss_LB)

        Discriminator_loss = D_loss_A + D_loss_B
        Discriminator_loss.backward()
        optimizers['optimD'].step()

        # update G
        optimizers['optimG'].clear_grad()

        # re-run generators so the graph is fresh after the D step
        fake_A2B, fake_A2B_cam_logit, _ = self.nets['genA2B'](self.real_A)
        fake_B2A, fake_B2A_cam_logit, _ = self.nets['genB2A'](self.real_B)

        # cycle reconstructions A->B->A and B->A->B
        fake_A2B2A, _, _ = self.nets['genB2A'](fake_A2B)
        fake_B2A2B, _, _ = self.nets['genA2B'](fake_B2A)

        # identity mappings: feed an image to the generator of its own domain
        fake_A2A, fake_A2A_cam_logit, _ = self.nets['genB2A'](self.real_A)
        fake_B2B, fake_B2B_cam_logit, _ = self.nets['genA2B'](self.real_B)

        fake_GA_logit, fake_GA_cam_logit, _ = self.nets['disGA'](fake_B2A)
        fake_LA_logit, fake_LA_cam_logit, _ = self.nets['disLA'](fake_B2A)
        fake_GB_logit, fake_GB_cam_logit, _ = self.nets['disGB'](fake_A2B)
        fake_LB_logit, fake_LB_cam_logit, _ = self.nets['disLB'](fake_A2B)

        # generators try to make discriminators output "real" (1) on fakes
        G_ad_loss_GA = _criterion(self.MSE_loss, fake_GA_logit, True)
        G_ad_cam_loss_GA = _criterion(self.MSE_loss, fake_GA_cam_logit, True)
        G_ad_loss_LA = _criterion(self.MSE_loss, fake_LA_logit, True)
        G_ad_cam_loss_LA = _criterion(self.MSE_loss, fake_LA_cam_logit, True)
        G_ad_loss_GB = _criterion(self.MSE_loss, fake_GB_logit, True)
        G_ad_cam_loss_GB = _criterion(self.MSE_loss, fake_GB_cam_logit, True)
        G_ad_loss_LB = _criterion(self.MSE_loss, fake_LB_logit, True)
        G_ad_cam_loss_LB = _criterion(self.MSE_loss, fake_LB_cam_logit, True)

        G_recon_loss_A = self.L1_loss(fake_A2B2A, self.real_A)
        G_recon_loss_B = self.L1_loss(fake_B2A2B, self.real_B)

        G_identity_loss_A = self.L1_loss(fake_A2A, self.real_A)
        G_identity_loss_B = self.L1_loss(fake_B2B, self.real_B)

        # CAM loss: translated-image CAM logit -> 1, identity CAM logit -> 0
        G_cam_loss_A = _criterion(self.BCE_loss,
                                  fake_B2A_cam_logit, True) + _criterion(
                                      self.BCE_loss, fake_A2A_cam_logit, False)

        G_cam_loss_B = _criterion(self.BCE_loss,
                                  fake_A2B_cam_logit, True) + _criterion(
                                      self.BCE_loss, fake_B2B_cam_logit, False)

        G_loss_A = self.adv_weight * (
            G_ad_loss_GA + G_ad_cam_loss_GA + G_ad_loss_LA + G_ad_cam_loss_LA
        ) + self.cycle_weight * G_recon_loss_A + self.identity_weight * G_identity_loss_A + self.cam_weight * G_cam_loss_A
        G_loss_B = self.adv_weight * (
            G_ad_loss_GB + G_ad_cam_loss_GB + G_ad_loss_LB + G_ad_cam_loss_LB
        ) + self.cycle_weight * G_recon_loss_B + self.identity_weight * G_identity_loss_B + self.cam_weight * G_cam_loss_B

        Generator_loss = G_loss_A + G_loss_B
        Generator_loss.backward()
        optimizers['optimG'].step()

        # clip parameter of AdaILN and ILN, applied after optimizer step
        self.nets['genA2B'].apply(self.Rho_clipper)
        self.nets['genB2A'].apply(self.Rho_clipper)

        self.losses['discriminator_loss'] = Discriminator_loss
        self.losses['generator_loss'] = Generator_loss