makeup_model.py 17.7 KB
Newer Older
Q
qingqing01 已提交
1
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
L
lijianshe02 已提交
2 3 4 5 6 7 8 9 10 11 12 13
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
L
LielinJiang 已提交
14
import numpy as np
L
lijianshe02 已提交
15

L
lijianshe02 已提交
16 17 18
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
L
LielinJiang 已提交
19
from paddle.vision.models import vgg16
L
lijianshe02 已提交
20 21 22 23 24 25
from .base_model import BaseModel

from .builder import MODELS
from .generators.builder import build_generator
from .discriminators.builder import build_discriminator
from .losses import GANLoss
L
lijianshe02 已提交
26
from ..modules.init import init_weights
L
lijianshe02 已提交
27 28 29 30 31 32 33 34 35
from ..solver import build_optimizer
from ..utils.image_pool import ImagePool
from ..utils.preprocess import *
from ..datasets.makeup_dataset import MakeupDataset


@MODELS.register()
class MakeupModel(BaseModel):
    """PSGAN makeup-transfer model.

    PSGAN paper: https://arxiv.org/pdf/1909.06956.pdf

    A single generator ``netG`` serves both transfer directions:
    ``fake_A = netG(real_A, real_B, ...)`` is ``real_A`` re-rendered after
    ``real_B`` (its lip/skin/eye histograms are pulled towards ``real_B``'s
    in ``backward_G``), and ``fake_B`` is the reverse. In training, two
    discriminators are used:

    * ``netD_A`` scores ``real_B`` as real and ``fake_A`` as fake;
    * ``netD_B`` scores ``real_A`` as real and ``fake_B`` as fake.

    These pairings match what ``backward_G`` asks each discriminator to judge.
    """
    def __init__(self, cfg):
        """Initialize the PSGAN model.

        Parameters:
            cfg (dict)-- config of model. Reads ``cfg.model.generator``; in
                training mode also ``cfg.model.discriminator``,
                ``cfg.model.gan_mode``, ``cfg.dataset.train.pool_size`` and
                ``cfg.optimizer``.
        """
        super(MakeupModel, self).__init__(cfg)

        # One generator handles both transfer directions.
        self.nets['netG'] = build_generator(cfg.model.generator)
        init_weights(self.nets['netG'], init_type='xavier', init_gain=1.0)

        if self.is_train:
            # Pretrained VGG16 features for the perceptual loss. The weights
            # are not handed to any optimizer, so stop their gradients: they
            # would otherwise be computed (and never used or cleared) on
            # every generator step. Gradients still flow through VGG to the
            # generator's outputs.
            vgg = vgg16(pretrained=True)
            self.vgg = vgg.features
            for param in self.vgg.parameters():
                param.stop_gradient = True

            # Discriminators (see the class docstring for what each judges).
            self.nets['netD_A'] = build_discriminator(cfg.model.discriminator)
            self.nets['netD_B'] = build_discriminator(cfg.model.discriminator)
            init_weights(self.nets['netD_A'], init_type='xavier', init_gain=1.0)
            init_weights(self.nets['netD_B'], init_type='xavier', init_gain=1.0)

            # Buffers of previously generated images, queried when updating
            # the discriminators.
            self.fake_A_pool = ImagePool(cfg.dataset.train.pool_size)
            self.fake_B_pool = ImagePool(cfg.dataset.train.pool_size)

            # Loss functions.
            self.criterionGAN = GANLoss(cfg.model.gan_mode)
            self.criterionCycle = paddle.nn.L1Loss()
            self.criterionIdt = paddle.nn.L1Loss()
            self.criterionL1 = paddle.nn.L1Loss()
            self.criterionL2 = paddle.nn.MSELoss()

            # One optimizer per network, all sharing the same LR scheduler.
            self.build_lr_scheduler()
            self.optimizers['optimizer_G'] = build_optimizer(
                cfg.optimizer,
                self.lr_scheduler,
                parameter_list=self.nets['netG'].parameters())
            self.optimizers['optimizer_DA'] = build_optimizer(
                cfg.optimizer,
                self.lr_scheduler,
                parameter_list=self.nets['netD_A'].parameters())
            self.optimizers['optimizer_DB'] = build_optimizer(
                cfg.optimizer,
                self.lr_scheduler,
                parameter_list=self.nets['netD_B'].parameters())

    def set_input(self, input):
        """Unpack a dataloader batch into tensors stored on the model.

        Parameters:
            input (dict): batch from the makeup dataset. Always reads
                'image_A', 'image_B', 'consis_mask', 'P_A', 'P_B',
                'mask_A_aug' and 'mask_B_aug'; in training mode also reads
                'mask_A', 'mask_B', 'consis_mask_idt_A' and
                'consis_mask_idt_B'.
        """
        self.real_A = paddle.to_tensor(input['image_A'])
        self.real_B = paddle.to_tensor(input['image_B'])
        self.c_m = paddle.to_tensor(input['consis_mask'])
        self.P_A = paddle.to_tensor(input['P_A'])
        self.P_B = paddle.to_tensor(input['P_B'])
        self.mask_A_aug = paddle.to_tensor(input['mask_A_aug'])
        self.mask_B_aug = paddle.to_tensor(input['mask_B_aug'])
        # Consistency mask for the B-after-A direction: the A-after-B mask
        # with its last two axes swapped.
        self.c_m_t = paddle.transpose(self.c_m, perm=[0, 2, 1])
        if self.is_train:
            self.mask_A = paddle.to_tensor(input['mask_A'])
            self.mask_B = paddle.to_tensor(input['mask_B'])
            self.c_m_idt_a = paddle.to_tensor(input['consis_mask_idt_A'])
            self.c_m_idt_b = paddle.to_tensor(input['consis_mask_idt_B'])

    def forward(self):
        """Generate both transfers and their cycle reconstructions.

        Called by ``optimize_parameters``. Needs ``c_m_idt_a``/``c_m_idt_b``,
        which ``set_input`` only sets in training mode.
        """
        # NOTE(review): mask_A_aug/mask_B_aug are passed in the same order
        # for every call, whichever image is the source; confirm netG does
        # not depend on that order.
        # real_A after real_B, and real_B after real_A.
        self.fake_A, _ = self.nets['netG'](self.real_A, self.real_B, self.P_A,
                                           self.P_B, self.c_m,
                                           self.mask_A_aug, self.mask_B_aug)
        self.fake_B, _ = self.nets['netG'](self.real_B, self.real_A, self.P_B,
                                           self.P_A, self.c_m_t,
                                           self.mask_A_aug, self.mask_B_aug)
        # Cycle: re-render each fake back after its original image.
        self.rec_A, _ = self.nets['netG'](self.fake_A, self.real_A, self.P_A,
                                          self.P_A, self.c_m_idt_a,
                                          self.mask_A_aug, self.mask_B_aug)
        self.rec_B, _ = self.nets['netG'](self.fake_B, self.real_B, self.P_B,
                                          self.P_B, self.c_m_idt_b,
                                          self.mask_A_aug, self.mask_B_aug)

        # visual
        self.visual_items['real_A'] = self.real_A
        self.visual_items['fake_B'] = self.fake_B
        self.visual_items['rec_A'] = self.rec_A
        self.visual_items['real_B'] = self.real_B
        self.visual_items['fake_A'] = self.fake_A
        self.visual_items['rec_B'] = self.rec_B

    def forward_test(self, input):
        """Run netG on a raw batch dict: 'image_A' rendered after 'image_B'.

        Returns netG's output unchanged (``forward`` unpacks it as a pair
        whose first element is the generated image).
        """
        return self.nets['netG'](input['image_A'], input['image_B'],
                                 input['P_A'], input['P_B'],
                                 input['consis_mask'], input['mask_A_aug'],
                                 input['mask_B_aug'])

    def test(self, input):
        """Run ``forward_test`` inside ``paddle.no_grad()`` (no autograd graph)."""
        with paddle.no_grad():
            return self.forward_test(input)

    def backward_D_basic(self, netD, real, fake):
        """Calculate the GAN loss for one discriminator and backpropagate it.

        Parameters:
            netD (network)      -- the discriminator D
            real (tensor array) -- real images
            fake (tensor array) -- generated images (detached before use, so
                                   no gradient reaches the generator)

        Returns the discriminator loss (mean of the real and fake terms).
        """
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss and calculate gradients
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        """GAN loss for netD_A: real_B (real) vs. pooled fake_A (fake).

        ``backward_G`` trains netG so that netD_A accepts ``fake_A``; netD_A
        must therefore see ``fake_A`` as its negative examples.
        """
        fake_A = self.fake_A_pool.query(self.fake_A)
        self.loss_D_A = self.backward_D_basic(self.nets['netD_A'], self.real_B,
                                              fake_A)
        self.losses['D_A_loss'] = self.loss_D_A

    def backward_D_B(self):
        """GAN loss for netD_B: real_A (real) vs. pooled fake_B (fake).

        ``backward_G`` trains netG so that netD_B accepts ``fake_B``; netD_B
        must therefore see ``fake_B`` as its negative examples.
        """
        fake_B = self.fake_B_pool.query(self.fake_B)
        self.loss_D_B = self.backward_D_basic(self.nets['netD_B'], self.real_A,
                                              fake_B)
        self.losses['D_B_loss'] = self.loss_D_B

    @staticmethod
    def _to_pixels(image):
        """Map ``(image + 1) / 2``, clipped to [0, 1], onto [0, 255].

        Presumably generator images live in [-1, 1]; confirm.
        """
        return paddle.clip((image + 1.0) / 2.0, 0.0, 1.0) * 255.0

    @staticmethod
    def _as_target(array):
        """Wrap a histogram-matching numpy result as a constant batched tensor."""
        target = paddle.to_tensor(array)
        target.stop_gradient = True
        return target.unsqueeze(0)

    def _his_loss(self, fake_A, fake_B, fake_A_np, fake_B_np, real_A_np,
                  real_B_np, channel):
        """Histogram-matching L1 losses for one facial region.

        ``channel`` selects the region mask from ``mask_A_aug``/``mask_B_aug``
        (``backward_G`` uses 0 = lip, 1 = skin, 2 = eye). Each fake image's
        region is compared against a histogram-matched target built from the
        other real image's corresponding region.

        Returns:
            (loss_A, loss_B) for ``fake_A`` and ``fake_B`` respectively.
        """
        mask_A = self.mask_A_aug[:, channel].unsqueeze(1)
        mask_B = self.mask_B_aug[:, channel].unsqueeze(1)
        # NOTE(review): the numpy squeeze() here and unsqueeze(0) in
        # _as_target assume a batch size of 1.
        mask_A_np, mask_B_np, index_A, index_B = mask_preprocess(
            mask_A.numpy().squeeze(), mask_B.numpy().squeeze())

        match_A = self._as_target(
            hisMatch(fake_A_np, real_B_np, mask_A_np, mask_B_np, index_A))
        match_B = self._as_target(
            hisMatch(fake_B_np, real_A_np, mask_B_np, mask_A_np, index_B))

        loss_A = self.criterionL1(fake_A * mask_A, match_A)
        loss_B = self.criterionL1(fake_B * mask_B, match_B)
        return loss_A, loss_B

    def backward_G(self):
        """Compute the total generator loss and backpropagate it.

        The total combines:
          * adversarial losses: netD_A on fake_A, netD_B on fake_B;
          * identity losses (only when ``cfg.lambda_identity > 0``);
          * cycle-reconstruction and VGG perceptual losses (``loss_rec``);
          * lip/skin/eye histogram-matching losses;
          * a background-consistency L1 loss on fake_A.
        Individual terms are recorded in ``self.losses`` for logging.
        """
        lambda_idt = self.cfg.lambda_identity
        lambda_A = self.cfg.lambda_A
        lambda_B = self.cfg.lambda_B
        lambda_vgg = 5e-3

        # Identity loss: an image used as its own reference should come back
        # unchanged.
        if lambda_idt > 0:
            self.idt_A, _ = self.nets['netG'](self.real_A, self.real_A,
                                              self.P_A, self.P_A,
                                              self.c_m_idt_a, self.mask_A_aug,
                                              self.mask_B_aug)
            self.loss_idt_A = self.criterionIdt(
                self.idt_A, self.real_A) * lambda_A * lambda_idt
            self.idt_B, _ = self.nets['netG'](self.real_B, self.real_B,
                                              self.P_B, self.P_B,
                                              self.c_m_idt_b, self.mask_A_aug,
                                              self.mask_B_aug)
            self.loss_idt_B = self.criterionIdt(
                self.idt_B, self.real_B) * lambda_B * lambda_idt

            # visual
            self.visual_items['idt_A'] = self.idt_A
            self.visual_items['idt_B'] = self.idt_B
        else:
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        # Adversarial losses (must match the pairings in backward_D_A/B).
        self.loss_G_A = self.criterionGAN(self.nets['netD_A'](self.fake_A),
                                          True)
        self.loss_G_B = self.criterionGAN(self.nets['netD_B'](self.fake_B),
                                          True)
        # Cycle losses ||rec_A - A|| and ||rec_B - B||.
        self.loss_cycle_A = self.criterionCycle(self.rec_A,
                                                self.real_A) * lambda_A
        self.loss_cycle_B = self.criterionCycle(self.rec_B,
                                                self.real_B) * lambda_B

        self.losses['G_A_adv_loss'] = self.loss_G_A
        self.losses['G_B_adv_loss'] = self.loss_G_B

        # Histogram-matching losses, computed in [0, 255] pixel space.
        fake_A = self._to_pixels(self.fake_A)
        fake_B = self._to_pixels(self.fake_B)
        his_args = (fake_A, fake_B, fake_A.numpy().squeeze(),
                    fake_B.numpy().squeeze(),
                    self._to_pixels(self.real_A).numpy().squeeze(),
                    self._to_pixels(self.real_B).numpy().squeeze())
        g_A_lip_loss_his, g_B_lip_loss_his = self._his_loss(*his_args,
                                                            channel=0)
        g_A_skin_loss_his, g_B_skin_loss_his = self._his_loss(*his_args,
                                                              channel=1)
        g_A_eye_loss_his, g_B_eye_loss_his = self._his_loss(*his_args,
                                                            channel=2)

        # Skin is down-weighted by 0.1 relative to lip and eye.
        self.loss_G_A_his = (g_A_eye_loss_his + g_A_lip_loss_his +
                             g_A_skin_loss_his * 0.1) * 0.01
        self.loss_G_B_his = (g_B_eye_loss_his + g_B_lip_loss_his +
                             g_B_skin_loss_his * 0.1) * 0.01

        self.losses['G_A_his_loss'] = self.loss_G_A_his
        self.losses['G_B_his_loss'] = self.loss_G_B_his

        # VGG perceptual loss: each fake should stay close, in VGG feature
        # space, to the image it was generated from (those features are
        # treated as constants).
        vgg_s = self.vgg(self.real_A)
        vgg_s.stop_gradient = True
        vgg_fake_A = self.vgg(self.fake_A)
        self.loss_A_vgg = self.criterionL2(vgg_fake_A,
                                           vgg_s) * lambda_A * lambda_vgg

        vgg_r = self.vgg(self.real_B)
        vgg_r.stop_gradient = True
        vgg_fake_B = self.vgg(self.fake_B)
        self.loss_B_vgg = self.criterionL2(vgg_fake_B,
                                           vgg_r) * lambda_B * lambda_vgg

        self.loss_rec = (self.loss_cycle_A + self.loss_cycle_B +
                         self.loss_A_vgg + self.loss_B_vgg) * 0.2
        self.loss_idt = (self.loss_idt_A + self.loss_idt_B) * 0.2

        self.losses['G_A_vgg_loss'] = self.loss_A_vgg
        self.losses['G_B_vgg_loss'] = self.loss_B_vgg
        self.losses['G_rec_loss'] = self.loss_rec
        self.losses['G_idt_loss'] = self.loss_idt

        # Background-consistency loss: pixels of A whose mask_A label is 0,
        # 10 or 8 should be left unchanged in fake_A.
        # NOTE(review): these label ids presumably mark non-face regions;
        # confirm against the dataset's parsing labels.
        mask_A_consis = paddle.cast(
            (self.mask_A == 0), dtype='float32') + paddle.cast(
                (self.mask_A == 10), dtype='float32') + paddle.cast(
                    (self.mask_A == 8), dtype='float32')
        mask_A_consis = paddle.unsqueeze(paddle.clip(mask_A_consis, 0, 1), 1)
        self.loss_G_bg_consis = self.criterionL1(
            self.real_A * mask_A_consis, self.fake_A * mask_A_consis) * 0.1

        # Combined loss and gradients.
        self.loss_G = (self.loss_G_A + self.loss_G_B + self.loss_rec +
                       self.loss_idt + self.loss_G_A_his + self.loss_G_B_his +
                       self.loss_G_bg_consis)
        self.loss_G.backward()

    def optimize_parameters(self):
        """Run one training iteration: update netG, then netD_A, then netD_B."""
        # Compute fake and reconstructed images.
        self.forward()

        # Generator step; the discriminators require no gradients here.
        self.set_requires_grad([self.nets['netD_A'], self.nets['netD_B']],
                               False)
        self.backward_G()
        self.optimizers['optimizer_G'].minimize(self.loss_G)
        self.optimizers['optimizer_G'].clear_gradients()

        # Discriminator A step.
        self.set_requires_grad(self.nets['netD_A'], True)
        self.backward_D_A()
        self.optimizers['optimizer_DA'].minimize(self.loss_D_A)
        self.optimizers['optimizer_DA'].clear_gradients()

        # Discriminator B step.
        self.set_requires_grad(self.nets['netD_B'], True)
        self.backward_D_B()
        self.optimizers['optimizer_DB'].minimize(self.loss_D_B)
        self.optimizers['optimizer_DB'].clear_gradients()