# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import numpy as np

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.vision.models import vgg16
from paddle.utils.download import get_path_from_url
from .base_model import BaseModel

from .builder import MODELS
from .generators.builder import build_generator
from .discriminators.builder import build_discriminator
from .losses import GANLoss
from ..modules.init import init_weights
from ..solver import build_optimizer
from ..utils.image_pool import ImagePool
# NOTE: the wildcard import is what brings `mask_preprocess` and `hisMatch`
# (used by MakeupModel.backward_G) into scope; do not narrow it without
# checking every name the module relies on.
from ..utils.preprocess import *
from ..datasets.makeup_dataset import MakeupDataset

# Location of the VGG weights loaded into the VGG16 network that
# MakeupModel uses for its vgg loss during training.
VGGFACE_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/vggface.pdparams'


@MODELS.register()
class MakeupModel(BaseModel):
    """PSGAN makeup-transfer model.

    PSGAN paper: https://arxiv.org/pdf/1909.06956.pdf

    A single generator (``self.nets['netG']``) is always built. In training
    mode the model additionally builds two discriminators (``netD_A`` and
    ``netD_B``), a VGG16 feature extractor for the vgg loss, image pools,
    loss criteria and one optimizer per network.

    Attributes such as ``nets``, ``optimizers``, ``losses``, ``visual_items``,
    ``is_train``, ``cfg`` and ``lr_scheduler`` are not defined in this class;
    they are presumably provided by ``BaseModel`` -- confirm there.
    """
    def __init__(self, cfg):
        """Initialize the PSGAN class.

        Parameters:
            cfg -- config of the model. It is read with attribute access:
                ``cfg.model.generator``, ``cfg.model.discriminator``,
                ``cfg.model.gan_mode``, ``cfg.dataset.train.pool_size`` and
                ``cfg.optimizer``.
        """
        super(MakeupModel, self).__init__(cfg)

        # A single generator is used for every direction (A->B, B->A,
        # reconstruction and identity); see forward() and backward_G().
        self.nets['netG'] = build_generator(cfg.model.generator)
        init_weights(self.nets['netG'], init_type='xavier', init_gain=1.0)

        if self.is_train:
            # VGG16 feature extractor used by the vgg loss in backward_G().
            # `self.vgg` is a sub-layer of `vgg`, so loading the weights into
            # `vgg` afterwards also fills `self.vgg`.
            vgg = vgg16(pretrained=False)
            self.vgg = vgg.features
            # The weight file is downloaded next to this source file.
            cur_path = os.path.abspath(os.path.dirname(__file__))
            vgg_weight_path = get_path_from_url(VGGFACE_WEIGHT_URL, cur_path)
            param = paddle.load(vgg_weight_path)
            vgg.load_dict(param)

            # Two discriminators built from the same config.
            self.nets['netD_A'] = build_discriminator(cfg.model.discriminator)
            self.nets['netD_B'] = build_discriminator(cfg.model.discriminator)
            init_weights(self.nets['netD_A'], init_type='xavier', init_gain=1.0)
            init_weights(self.nets['netD_B'], init_type='xavier', init_gain=1.0)

            # Image buffers that store previously generated images; they are
            # queried when training the discriminators.
            self.fake_A_pool = ImagePool(cfg.dataset.train.pool_size)
            self.fake_B_pool = ImagePool(cfg.dataset.train.pool_size)

            # Loss functions.
            self.criterionGAN = GANLoss(cfg.model.gan_mode)
            self.criterionCycle = paddle.nn.L1Loss()
            self.criterionIdt = paddle.nn.L1Loss()
            self.criterionL1 = paddle.nn.L1Loss()
            self.criterionL2 = paddle.nn.MSELoss()

            # One optimizer per network, all sharing the same LR scheduler.
            self.build_lr_scheduler()
            self.optimizers['optimizer_G'] = build_optimizer(
                cfg.optimizer,
                self.lr_scheduler,
                parameter_list=self.nets['netG'].parameters())
            self.optimizers['optimizer_DA'] = build_optimizer(
                cfg.optimizer,
                self.lr_scheduler,
                parameter_list=self.nets['netD_A'].parameters())
            self.optimizers['optimizer_DB'] = build_optimizer(
                cfg.optimizer,
                self.lr_scheduler,
                parameter_list=self.nets['netD_B'].parameters())

    def set_input(self, input):
        """Unpack input data from the dataloader into tensors on ``self``.

        Parameters:
            input (dict): must contain 'image_A', 'image_B', 'consis_mask',
                'P_A', 'P_B', 'mask_A_aug' and 'mask_B_aug'. In training mode
                it must also contain 'mask_A', 'mask_B', 'consis_mask_idt_A'
                and 'consis_mask_idt_B'.
        """
        self.real_A = paddle.to_tensor(input['image_A'])
        self.real_B = paddle.to_tensor(input['image_B'])
        self.c_m = paddle.to_tensor(input['consis_mask'])
        self.P_A = paddle.to_tensor(input['P_A'])
        self.P_B = paddle.to_tensor(input['P_B'])
        self.mask_A_aug = paddle.to_tensor(input['mask_A_aug'])
        self.mask_B_aug = paddle.to_tensor(input['mask_B_aug'])
        # Consistency mask with its last two axes swapped; used when the
        # roles of A and B are exchanged in forward().
        self.c_m_t = paddle.transpose(self.c_m, perm=[0, 2, 1])
        if self.is_train:
            self.mask_A = paddle.to_tensor(input['mask_A'])
            self.mask_B = paddle.to_tensor(input['mask_B'])
            self.c_m_idt_a = paddle.to_tensor(input['consis_mask_idt_A'])
            self.c_m_idt_b = paddle.to_tensor(input['consis_mask_idt_B'])

    def forward(self):
        """Run the training forward pass; called by <optimize_parameters>.

        Produces ``fake_A``, ``fake_B``, ``rec_A`` and ``rec_B`` and stores
        them, together with the real images, in ``self.visual_items``.
        It reads ``c_m_idt_a`` / ``c_m_idt_b``, which set_input() only
        creates in training mode, so it cannot be used at test time (use
        <forward_test> instead).
        """
        # NOTE(review): every call below passes (mask_A_aug, mask_B_aug) in
        # the same order, even where the A/B roles of the images are swapped.
        # Confirm against the generator's signature that this is intended.

        # netG(real_A, real_B): first output is stored as fake_A; the second
        # output is bound to `amm` but not used afterwards.
        self.fake_A, amm = self.nets['netG'](self.real_A, self.real_B, self.P_A,
                                             self.P_B, self.c_m,
                                             self.mask_A_aug,
                                             self.mask_B_aug)
        # netG(real_B, real_A) with the transposed consistency mask.
        self.fake_B, _ = self.nets['netG'](self.real_B, self.real_A, self.P_B,
                                           self.P_A, self.c_m_t,
                                           self.mask_A_aug,
                                           self.mask_B_aug)
        # Reconstructions: feed each fake image back together with its own
        # real image; compared to the real images by the cycle loss.
        self.rec_A, _ = self.nets['netG'](self.fake_A, self.real_A, self.P_A,
                                          self.P_A, self.c_m_idt_a,
                                          self.mask_A_aug,
                                          self.mask_B_aug)
        self.rec_B, _ = self.nets['netG'](self.fake_B, self.real_B, self.P_B,
                                          self.P_B, self.c_m_idt_b,
                                          self.mask_A_aug,
                                          self.mask_B_aug)

        # visual
        self.visual_items['real_A'] = self.real_A
        self.visual_items['fake_B'] = self.fake_B
        self.visual_items['rec_A'] = self.rec_A
        self.visual_items['real_B'] = self.real_B
        self.visual_items['fake_A'] = self.fake_A
        self.visual_items['rec_B'] = self.rec_B

    def forward_test(self, input):
        """Run the generator once on a raw input dict and return its output.

        Parameters:
            input (dict): must contain 'image_A', 'image_B', 'P_A', 'P_B',
                'consis_mask', 'mask_A_aug' and 'mask_B_aug'. Unlike
                set_input(), the values are passed to the generator as-is
                (no ``paddle.to_tensor`` conversion is applied here).

        Returns whatever ``netG`` returns (forward() unpacks it as a pair).
        """
        return self.nets['netG'](input['image_A'], input['image_B'],
                                 input['P_A'], input['P_B'],
                                 input['consis_mask'], input['mask_A_aug'],
                                 input['mask_B_aug'])

    def test(self, input):
        """Forward function used in test time.

        Wraps <forward_test> in ``paddle.no_grad()`` so no intermediate
        results are kept for backprop, and returns its result.
        """
        with paddle.no_grad():
            return self.forward_test(input)

    def backward_D_basic(self, netD, real, fake):
        """Calculate GAN loss for the discriminator

        Parameters:
            netD (network)      -- the discriminator D
            real (tensor array) -- real images
            fake (tensor array) -- images generated by a generator

        Return the discriminator loss.
        We also call loss_D.backward() to calculate the gradients.
        """
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake; detach so no gradient flows back into the generator.
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss and calculate gradients
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        """Calculate GAN loss for discriminator D_A (real_B vs. pooled fake_B).

        Stores the loss in ``self.loss_D_A`` and ``self.losses['D_A_loss']``.
        """
        # NOTE(review): backward_G() evaluates netD_A on fake_A, while this
        # method trains netD_A on real_B / fake_B. Confirm the pairing.
        fake_B = self.fake_B_pool.query(self.fake_B)
        self.loss_D_A = self.backward_D_basic(self.nets['netD_A'], self.real_B,
                                              fake_B)
        self.losses['D_A_loss'] = self.loss_D_A

    def backward_D_B(self):
        """Calculate GAN loss for discriminator D_B (real_A vs. pooled fake_A).

        Stores the loss in ``self.loss_D_B`` and ``self.losses['D_B_loss']``.
        """
        fake_A = self.fake_A_pool.query(self.fake_A)
        self.loss_D_B = self.backward_D_basic(self.nets['netD_B'], self.real_A,
                                              fake_A)
        self.losses['D_B_loss'] = self.loss_D_B

    def _region_his_loss(self, channel, fake_A, fake_B, fake_A_np, fake_B_np,
                         real_A_np, real_B_np):
        """Histogram-matching L1 loss for one channel of the augmented masks.

        Parameters:
            channel (int) -- channel of ``mask_A_aug`` / ``mask_B_aug`` to use
            fake_A, fake_B -- generated images rescaled to [0, 255] (tensors)
            fake_A_np, fake_B_np, real_A_np, real_B_np -- squeezed numpy
                copies of the rescaled fake / real images

        Returns the pair ``(loss_A, loss_B)``.
        """
        mask_A = self.mask_A_aug[:, channel].unsqueeze(1)
        mask_B = self.mask_B_aug[:, channel].unsqueeze(1)

        mask_A_np = mask_A.numpy().squeeze()
        mask_B_np = mask_B.numpy().squeeze()
        mask_A_np, mask_B_np, index_A, index_B = mask_preprocess(
            mask_A_np, mask_B_np)

        # Histogram-matched targets are computed in numpy.
        match_A = hisMatch(fake_A_np, real_B_np, mask_A_np, mask_B_np, index_A)
        match_B = hisMatch(fake_B_np, real_A_np, mask_B_np, mask_A_np, index_B)

        # The targets are constants for the loss: no gradient through them.
        match_A = paddle.to_tensor(match_A)
        match_A.stop_gradient = True
        match_A = match_A.unsqueeze(0)
        match_B = paddle.to_tensor(match_B)
        match_B.stop_gradient = True
        match_B = match_B.unsqueeze(0)

        loss_A = self.criterionL1(fake_A * mask_A, match_A)
        loss_B = self.criterionL1(fake_B * mask_B, match_B)
        return loss_A, loss_B

    def backward_G(self):
        """Calculate all generator losses and back-propagate their sum.

        The total is the sum of: adversarial losses, reconstruction loss
        (cycle + vgg), identity loss, histogram-matching losses (lip, skin,
        eye) and a background consistency loss. Individual terms are logged
        in ``self.losses``; the total is stored in ``self.loss_G``.
        """
        lambda_idt = self.cfg.lambda_identity
        lambda_A = self.cfg.lambda_A
        lambda_B = self.cfg.lambda_B
        lambda_vgg = 5e-3

        # Identity loss: feeding an image as both generator inputs should
        # give that image back.
        if lambda_idt > 0:
            self.idt_A, _ = self.nets['netG'](self.real_A, self.real_A,
                                              self.P_A, self.P_A,
                                              self.c_m_idt_a, self.mask_A_aug,
                                              self.mask_B_aug)
            self.loss_idt_A = self.criterionIdt(
                self.idt_A, self.real_A) * lambda_A * lambda_idt
            self.idt_B, _ = self.nets['netG'](self.real_B, self.real_B,
                                              self.P_B, self.P_B,
                                              self.c_m_idt_b, self.mask_A_aug,
                                              self.mask_B_aug)
            self.loss_idt_B = self.criterionIdt(
                self.idt_B, self.real_B) * lambda_B * lambda_idt

            # visual
            self.visual_items['idt_A'] = self.idt_A
            self.visual_items['idt_B'] = self.idt_B
        else:
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        # Adversarial losses: netD_A scores fake_A, netD_B scores fake_B.
        self.loss_G_A = self.criterionGAN(self.nets['netD_A'](self.fake_A),
                                          True)
        self.loss_G_B = self.criterionGAN(self.nets['netD_B'](self.fake_B),
                                          True)
        # Cycle losses: reconstructions should match the real images.
        self.loss_cycle_A = self.criterionCycle(self.rec_A,
                                                self.real_A) * lambda_A
        self.loss_cycle_B = self.criterionCycle(self.rec_B,
                                                self.real_B) * lambda_B

        self.losses['G_A_adv_loss'] = self.loss_G_A
        self.losses['G_B_adv_loss'] = self.loss_G_B

        # Histogram-matching losses. Images are mapped from [-1, 1] to
        # [0, 255] first. `paddle.clip` is the tensor clamp; `paddle.nn.clip`
        # is not a callable in Paddle 2.x.
        real_A = paddle.clip((self.real_A + 1.0) / 2.0, 0.0, 1.0) * 255.0
        real_A_np = real_A.numpy().squeeze()
        real_B = paddle.clip((self.real_B + 1.0) / 2.0, 0.0, 1.0) * 255.0
        real_B_np = real_B.numpy().squeeze()
        fake_A = paddle.clip((self.fake_A + 1.0) / 2.0, 0.0, 1.0) * 255.0
        fake_A_np = fake_A.numpy().squeeze()
        fake_B = paddle.clip((self.fake_B + 1.0) / 2.0, 0.0, 1.0) * 255.0
        fake_B_np = fake_B.numpy().squeeze()

        # Channel 0: lip, channel 1: skin, channel 2: eye.
        g_A_lip_loss_his, g_B_lip_loss_his = self._region_his_loss(
            0, fake_A, fake_B, fake_A_np, fake_B_np, real_A_np, real_B_np)
        g_A_skin_loss_his, g_B_skin_loss_his = self._region_his_loss(
            1, fake_A, fake_B, fake_A_np, fake_B_np, real_A_np, real_B_np)
        g_A_eye_loss_his, g_B_eye_loss_his = self._region_his_loss(
            2, fake_A, fake_B, fake_A_np, fake_B_np, real_A_np, real_B_np)

        # The skin term is down-weighted by 0.1 relative to eye and lip.
        self.loss_G_A_his = (g_A_eye_loss_his + g_A_lip_loss_his +
                             g_A_skin_loss_his * 0.1) * 0.1
        self.loss_G_B_his = (g_B_eye_loss_his + g_B_lip_loss_his +
                             g_B_skin_loss_his * 0.1) * 0.1

        self.losses['G_A_his_loss'] = self.loss_G_A_his
        self.losses['G_B_his_loss'] = self.loss_G_B_his

        # vgg loss: VGG features of each fake image should match those of
        # the real image it was generated from.
        vgg_s = self.vgg(self.real_A)
        vgg_s.stop_gradient = True
        vgg_fake_A = self.vgg(self.fake_A)
        self.loss_A_vgg = self.criterionL2(vgg_fake_A,
                                           vgg_s) * lambda_A * lambda_vgg

        vgg_r = self.vgg(self.real_B)
        vgg_r.stop_gradient = True
        vgg_fake_B = self.vgg(self.fake_B)
        self.loss_B_vgg = self.criterionL2(vgg_fake_B,
                                           vgg_r) * lambda_B * lambda_vgg

        self.loss_rec = (self.loss_cycle_A * 0.2 + self.loss_cycle_B * 0.2 +
                         self.loss_A_vgg + self.loss_B_vgg) * 0.5
        self.loss_idt = (self.loss_idt_A + self.loss_idt_B) * 0.1

        self.losses['G_A_vgg_loss'] = self.loss_A_vgg
        self.losses['G_B_vgg_loss'] = self.loss_B_vgg
        self.losses['G_rec_loss'] = self.loss_rec
        self.losses['G_idt_loss'] = self.loss_idt

        # bg consistency loss: real_A and fake_A should agree wherever
        # mask_A is 0, 10 or 8.
        # NOTE(review): presumably these are the background-like labels of
        # the parsing map -- confirm against the dataset.
        mask_A_consis = paddle.cast(
            (self.mask_A == 0), dtype='float32') + paddle.cast(
                (self.mask_A == 10), dtype='float32') + paddle.cast(
                    (self.mask_A == 8), dtype='float32')
        mask_A_consis = paddle.unsqueeze(paddle.clip(mask_A_consis, 0, 1), 1)
        self.loss_G_bg_consis = self.criterionL1(
            self.real_A * mask_A_consis, self.fake_A * mask_A_consis) * 0.1

        # combined loss and calculate gradients
        self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_rec + self.loss_idt + self.loss_G_A_his + self.loss_G_B_his + self.loss_G_bg_consis
        self.loss_G.backward()

    def optimize_parameters(self):
        """Calculate losses, gradients, and update network weights; called in every training iteration"""
        # forward: compute fake images and reconstruction images.
        self.forward()

        # Generator step: the discriminators need no gradients here.
        self.set_requires_grad(
            [self.nets['netD_A'], self.nets['netD_B']],
            False)
        self.backward_G()
        self.optimizers['optimizer_G'].minimize(self.loss_G)
        self.optimizers['optimizer_G'].clear_gradients()

        # D_A step.
        self.set_requires_grad(self.nets['netD_A'], True)
        self.backward_D_A()
        self.optimizers['optimizer_DA'].minimize(self.loss_D_A)
        self.optimizers['optimizer_DA'].clear_gradients()

        # D_B step.
        self.set_requires_grad(self.nets['netD_B'], True)
        self.backward_D_B()
        self.optimizers['optimizer_DB'].minimize(self.loss_D_B)
        self.optimizers['optimizer_DB'].clear_gradients()