# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import numpy as np

import paddle

from paddle.vision.models import vgg16
from paddle.utils.download import get_path_from_url
from .base_model import BaseModel

from .builder import MODELS
from .generators.builder import build_generator
from .discriminators.builder import build_discriminator
from .criterions import build_criterion
from ..modules.init import init_weights
from ..utils.image_pool import ImagePool
from ..utils.preprocess import *

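# pretrained VGG-Face weights, used by the perceptual (VGG) loss in backward_G()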
VGGFACE_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/vggface.pdparams'


@MODELS.register()
class MakeupModel(BaseModel):
    """
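    PSGAN: Pose and Expression Robust Spatial-Aware GAN for Customizable Makeup Transfer.
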
    PSGAN paper: https://arxiv.org/pdf/1909.06956.pdf
    """
    def __init__(self,
                 generator,
                 discriminator=None,
                 cycle_criterion=None,
                 idt_criterion=None,
                 gan_criterion=None,
                 l1_criterion=None,
                 l2_criterion=None,
                 pool_size=50,
                 direction='a2b',
                 lambda_a=10.,
                 lambda_b=10.,
                 is_train=True):
        """Initialize the PSGAN class.
L
lijianshe02 已提交
53 54

        Parameters:
L
LielinJiang 已提交
55
            cfg (dict)-- config of model.
L
lijianshe02 已提交
56
        """
        super(MakeupModel, self).__init__()
        self.lambda_a = lambda_a
        self.lambda_b = lambda_b
        self.is_train = is_train
        # define networks (both Generators and discriminators)
        # The naming is different from those used in the paper.
        # Code (vs. paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
        self.nets['netG'] = build_generator(generator)
        init_weights(self.nets['netG'], init_type='xavier', init_gain=1.0)

        if self.is_train:  # define discriminators
            vgg = vgg16(pretrained=False)
            self.vgg = vgg.features
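            # download the VGG-Face weights on first use and load them into the
            # vgg16 backbone; its conv features (self.vgg) drive the perceptual loss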
            cur_path = os.path.abspath(os.path.dirname(__file__))
            vgg_weight_path = get_path_from_url(VGGFACE_WEIGHT_URL, cur_path)
            param = paddle.load(vgg_weight_path)
            vgg.load_dict(param)

            self.nets['netD_A'] = build_discriminator(discriminator)
            self.nets['netD_B'] = build_discriminator(discriminator)
            init_weights(self.nets['netD_A'], init_type='xavier', init_gain=1.0)
            init_weights(self.nets['netD_B'], init_type='xavier', init_gain=1.0)

            # create image buffer to store previously generated images
            self.fake_A_pool = ImagePool(pool_size)
            self.fake_B_pool = ImagePool(pool_size)

            # define loss functions
            if gan_criterion:
                self.gan_criterion = build_criterion(gan_criterion)
            if cycle_criterion:
                self.cycle_criterion = build_criterion(cycle_criterion)
            if idt_criterion:
                self.idt_criterion = build_criterion(idt_criterion)
            if l1_criterion:
                self.l1_criterion = build_criterion(l1_criterion)
            if l2_criterion:
                self.l2_criterion = build_criterion(l2_criterion)

    def setup_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        Args:
            input (dict): includes the data itself and its metadata information.

        """
        self.real_A = paddle.to_tensor(input['image_A'])
        self.real_B = paddle.to_tensor(input['image_B'])
        self.c_m = paddle.to_tensor(input['consis_mask'])
        self.P_A = paddle.to_tensor(input['P_A'])
        self.P_B = paddle.to_tensor(input['P_B'])
        self.mask_A_aug = paddle.to_tensor(input['mask_A_aug'])
        self.mask_B_aug = paddle.to_tensor(input['mask_B_aug'])
        self.c_m_t = paddle.transpose(self.c_m, perm=[0, 2, 1])
        if self.is_train:
            self.mask_A = paddle.to_tensor(input['mask_A'])
            self.mask_B = paddle.to_tensor(input['mask_B'])
            self.c_m_idt_a = paddle.to_tensor(input['consis_mask_idt_A'])
            self.c_m_idt_b = paddle.to_tensor(input['consis_mask_idt_B'])

    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
        self.fake_A, amm = self.nets['netG'](self.real_A, self.real_B, self.P_A,
                                             self.P_B, self.c_m,
                                             self.mask_A_aug,
                                             self.mask_B_aug)  # G(A, B): A with B's makeup
        self.fake_B, _ = self.nets['netG'](self.real_B, self.real_A, self.P_B,
                                           self.P_A, self.c_m_t,
                                           self.mask_A_aug,
                                           self.mask_B_aug)  # G(B, A): B with A's makeup
        self.rec_A, _ = self.nets['netG'](self.fake_A, self.real_A, self.P_A,
                                          self.P_A, self.c_m_idt_a,
                                          self.mask_A_aug,
                                          self.mask_B_aug)  # G(G(A, B), A): reconstruct A
        self.rec_B, _ = self.nets['netG'](self.fake_B, self.real_B, self.P_B,
                                          self.P_B, self.c_m_idt_b,
                                          self.mask_A_aug,
                                          self.mask_B_aug)  # G(G(B, A), B): reconstruct B

        # visual
        self.visual_items['real_A'] = self.real_A
        self.visual_items['fake_B'] = self.fake_B
        self.visual_items['rec_A'] = self.rec_A
        self.visual_items['real_B'] = self.real_B
        self.visual_items['fake_A'] = self.fake_A
        self.visual_items['rec_B'] = self.rec_B

    def test(self, input):
        with paddle.no_grad():
            return self.nets['netG'](input['image_A'], input['image_B'],
                                     input['P_A'], input['P_B'],
                                     input['consis_mask'], input['mask_A_aug'],
                                     input['mask_B_aug'])

    def backward_D_basic(self, netD, real, fake):
        """Calculate GAN loss for the discriminator

        Parameters:
            netD (network)      -- the discriminator D
            real (tensor array) -- real images
            fake (tensor array) -- images generated by a generator

        Return the discriminator loss.
        We also call loss_D.backward() to calculate the gradients.
        """
        # Real
        pred_real = netD(real)
        loss_D_real = self.gan_criterion(pred_real, True)
        # Fake
        pred_fake = netD(fake.detach())
        loss_D_fake = self.gan_criterion(pred_fake, False)
        # Combined loss and calculate gradients
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        """Calculate GAN loss for discriminator D_A"""
        fake_B = self.fake_B_pool.query(self.fake_B)
        self.loss_D_A = self.backward_D_basic(self.nets['netD_A'], self.real_B,
                                              fake_B)
        self.losses['D_A_loss'] = self.loss_D_A

    def backward_D_B(self):
        """Calculate GAN loss for discriminator D_B"""
        fake_A = self.fake_A_pool.query(self.fake_A)
        self.loss_D_B = self.backward_D_basic(self.nets['netD_B'], self.real_A,
                                              fake_A)
        self.losses['D_B_loss'] = self.loss_D_B

    def backward_G(self):
        """Calculate the loss for generators G_A and G_B"""

        lambda_A = self.lambda_a
        lambda_B = self.lambda_b
        lambda_vgg = 5e-3

        # Identity loss
        if self.idt_criterion:
            self.idt_A, _ = self.nets['netG'](self.real_A, self.real_A,
                                              self.P_A, self.P_A,
                                              self.c_m_idt_a, self.mask_A_aug,
                                              self.mask_B_aug)  # G(A, A): identity
            self.loss_idt_A = self.idt_criterion(self.idt_A,
                                                 self.real_A) * lambda_A
            self.idt_B, _ = self.nets['netG'](self.real_B, self.real_B,
                                              self.P_B, self.P_B,
                                              self.c_m_idt_b, self.mask_A_aug,
                                              self.mask_B_aug)  # G(B, B): identity
            self.loss_idt_B = self.idt_criterion(self.idt_B,
                                                 self.real_B) * lambda_B

            # visual
            self.visual_items['idt_A'] = self.idt_A
            self.visual_items['idt_B'] = self.idt_B
        else:
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        # GAN loss D_A(G(A, B))
        self.loss_G_A = self.gan_criterion(self.nets['netD_A'](self.fake_A),
                                           True)
        # GAN loss D_B(G(B, A))
        self.loss_G_B = self.gan_criterion(self.nets['netD_B'](self.fake_B),
                                           True)
        # Forward cycle loss || G(G(A, B), A) - A ||
        self.loss_cycle_A = self.cycle_criterion(self.rec_A,
                                                 self.real_A) * lambda_A
        # Backward cycle loss || G(G(B, A), B) - B ||
        self.loss_cycle_B = self.cycle_criterion(self.rec_B,
                                                 self.real_B) * lambda_B

        self.losses['G_A_adv_loss'] = self.loss_G_A
        self.losses['G_B_adv_loss'] = self.loss_G_B

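        # makeup (histogram-matching) losses: inside the lip, skin and eye masks
        # the generated image is pulled, via an L1 loss, towards a version of itself
        # whose color histogram has been matched (hisMatch) to the reference image

        # lip region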
        mask_A_lip = self.mask_A_aug[:, 0].unsqueeze(1)
        mask_B_lip = self.mask_B_aug[:, 0].unsqueeze(1)

        mask_A_lip_np = mask_A_lip.numpy().squeeze()
        mask_B_lip_np = mask_B_lip.numpy().squeeze()
        mask_A_lip_np, mask_B_lip_np, index_A_lip, index_B_lip = mask_preprocess(
            mask_A_lip_np, mask_B_lip_np)
        real_A = paddle.clip((self.real_A + 1.0) / 2.0, 0.0, 1.0) * 255.0
        real_A_np = real_A.numpy().squeeze()
        real_B = paddle.clip((self.real_B + 1.0) / 2.0, 0.0, 1.0) * 255.0
        real_B_np = real_B.numpy().squeeze()
        fake_A = paddle.clip((self.fake_A + 1.0) / 2.0, 0.0, 1.0) * 255.0
        fake_A_np = fake_A.numpy().squeeze()
        fake_B = paddle.clip((self.fake_B + 1.0) / 2.0, 0.0, 1.0) * 255.0
        fake_B_np = fake_B.numpy().squeeze()

        fake_match_lip_A = hisMatch(fake_A_np, real_B_np, mask_A_lip_np,
                                    mask_B_lip_np, index_A_lip)
        fake_match_lip_B = hisMatch(fake_B_np, real_A_np, mask_B_lip_np,
                                    mask_A_lip_np, index_B_lip)
        fake_match_lip_A = paddle.to_tensor(fake_match_lip_A)
        fake_match_lip_A.stop_gradient = True
        fake_match_lip_A = fake_match_lip_A.unsqueeze(0)
        fake_match_lip_B = paddle.to_tensor(fake_match_lip_B)
        fake_match_lip_B.stop_gradient = True
        fake_match_lip_B = fake_match_lip_B.unsqueeze(0)
        fake_A_lip_masked = fake_A * mask_A_lip
        fake_B_lip_masked = fake_B * mask_B_lip
        g_A_lip_loss_his = self.l1_criterion(fake_A_lip_masked,
                                             fake_match_lip_A)
        g_B_lip_loss_his = self.l1_criterion(fake_B_lip_masked,
                                             fake_match_lip_B)

        # skin region
        mask_A_skin = self.mask_A_aug[:, 1].unsqueeze(1)
        mask_B_skin = self.mask_B_aug[:, 1].unsqueeze(1)

        mask_A_skin_np = mask_A_skin.numpy().squeeze()
        mask_B_skin_np = mask_B_skin.numpy().squeeze()
        mask_A_skin_np, mask_B_skin_np, index_A_skin, index_B_skin = mask_preprocess(
            mask_A_skin_np, mask_B_skin_np)

        fake_match_skin_A = hisMatch(fake_A_np, real_B_np, mask_A_skin_np,
                                     mask_B_skin_np, index_A_skin)
        fake_match_skin_B = hisMatch(fake_B_np, real_A_np, mask_B_skin_np,
                                     mask_A_skin_np, index_B_skin)
        fake_match_skin_A = paddle.to_tensor(fake_match_skin_A)
        fake_match_skin_A.stop_gradient = True
        fake_match_skin_A = fake_match_skin_A.unsqueeze(0)
        fake_match_skin_B = paddle.to_tensor(fake_match_skin_B)
        fake_match_skin_B.stop_gradient = True
        fake_match_skin_B = fake_match_skin_B.unsqueeze(0)
        fake_A_skin_masked = fake_A * mask_A_skin
        fake_B_skin_masked = fake_B * mask_B_skin
        g_A_skin_loss_his = self.l1_criterion(fake_A_skin_masked,
                                              fake_match_skin_A)
        g_B_skin_loss_his = self.l1_criterion(fake_B_skin_masked,
                                              fake_match_skin_B)

        # eye region
        mask_A_eye = self.mask_A_aug[:, 2].unsqueeze(1)
        mask_B_eye = self.mask_B_aug[:, 2].unsqueeze(1)

        mask_A_eye_np = mask_A_eye.numpy().squeeze()
        mask_B_eye_np = mask_B_eye.numpy().squeeze()
        mask_A_eye_np, mask_B_eye_np, index_A_eye, index_B_eye = mask_preprocess(
            mask_A_eye_np, mask_B_eye_np)

        fake_match_eye_A = hisMatch(fake_A_np, real_B_np, mask_A_eye_np,
                                    mask_B_eye_np, index_A_eye)
        fake_match_eye_B = hisMatch(fake_B_np, real_A_np, mask_B_eye_np,
                                    mask_A_eye_np, index_B_eye)
        fake_match_eye_A = paddle.to_tensor(fake_match_eye_A)
        fake_match_eye_A.stop_gradient = True
        fake_match_eye_A = fake_match_eye_A.unsqueeze(0)
        fake_match_eye_B = paddle.to_tensor(fake_match_eye_B)
        fake_match_eye_B.stop_gradient = True
        fake_match_eye_B = fake_match_eye_B.unsqueeze(0)
        fake_A_eye_masked = fake_A * mask_A_eye
        fake_B_eye_masked = fake_B * mask_B_eye
        g_A_eye_loss_his = self.l1_criterion(fake_A_eye_masked,
                                             fake_match_eye_A)
        g_B_eye_loss_his = self.l1_criterion(fake_B_eye_masked,
                                             fake_match_eye_B)

        self.loss_G_A_his = (g_A_eye_loss_his + g_A_lip_loss_his +
                             g_A_skin_loss_his * 0.1) * 0.1
        self.loss_G_B_his = (g_B_eye_loss_his + g_B_lip_loss_his +
                             g_B_skin_loss_his * 0.1) * 0.1

        self.losses['G_A_his_loss'] = self.loss_G_A_his
        self.losses['G_B_his_loss'] = self.loss_G_B_his

        # vgg loss
        vgg_s = self.vgg(self.real_A)
        vgg_s.stop_gradient = True
        vgg_fake_A = self.vgg(self.fake_A)
        self.loss_A_vgg = self.l2_criterion(vgg_fake_A,
                                            vgg_s) * lambda_A * lambda_vgg

        vgg_r = self.vgg(self.real_B)
        vgg_r.stop_gradient = True
        vgg_fake_B = self.vgg(self.fake_B)
        self.loss_B_vgg = self.l2_criterion(vgg_fake_B,
                                            vgg_r) * lambda_B * lambda_vgg

        self.loss_rec = (self.loss_cycle_A * 0.2 + self.loss_cycle_B * 0.2 +
                         self.loss_A_vgg + self.loss_B_vgg) * 0.5
        self.loss_idt = (self.loss_idt_A + self.loss_idt_B) * 0.1

        self.losses['G_A_vgg_loss'] = self.loss_A_vgg
        self.losses['G_B_vgg_loss'] = self.loss_B_vgg
        self.losses['G_rec_loss'] = self.loss_rec
        self.losses['G_idt_loss'] = self.loss_idt

        # bg consistency loss
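        # keep pixels whose face-parsing label is 0, 8 or 10 (treated here as
        # non-makeup/background regions) unchanged between real_A and fake_A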
        mask_A_consis = paddle.cast(
            (self.mask_A == 0), dtype='float32') + paddle.cast(
                (self.mask_A == 10), dtype='float32') + paddle.cast(
                    (self.mask_A == 8), dtype='float32')
        mask_A_consis = paddle.unsqueeze(paddle.clip(mask_A_consis, 0, 1), 1)
        self.loss_G_bg_consis = self.l1_criterion(
            self.real_A * mask_A_consis, self.fake_A * mask_A_consis) * 0.1

        # combined loss and calculate gradients
        self.loss_G = (self.loss_G_A + self.loss_G_B + self.loss_rec +
                       self.loss_idt + self.loss_G_A_his + self.loss_G_B_his +
                       self.loss_G_bg_consis)
        self.loss_G.backward()

    def train_iter(self, optimizers=None):
        """Calculate losses, gradients, and update network weights; called in every training iteration"""
        # forward
        self.forward()  # compute fake images and reconstruction images.
        # G_A and G_B
        self.set_requires_grad(
            [self.nets['netD_A'], self.nets['netD_B']],
            False)  # Ds require no gradients when optimizing Gs
        # self.optimizer_G.clear_gradients() #zero_grad()  # set G_A and G_B's gradients to zero
        self.backward_G()  # calculate gradients for G_A and G_B
        self.optimizers['optimizer_G'].minimize(
            self.loss_G)  # update the generator's weights
        self.optimizers['optimizer_G'].clear_gradients()
        # D_A and D_B
        self.set_requires_grad(self.nets['netD_A'], True)
        # self.optimizer_D.clear_gradients() #zero_grad()   # set D_A and D_B's gradients to zero
        self.backward_D_A()  # calculate gradients for D_A
        self.optimizers['optimizer_DA'].minimize(
            self.loss_D_A)  # update D_A's weights
        self.optimizers['optimizer_DA'].clear_gradients()
        self.set_requires_grad(self.nets['netD_B'], True)

        self.backward_D_B()  # calculate gradients for D_B
        self.optimizers['optimizer_DB'].minimize(
            self.loss_D_B)  # update D_B's weights
        self.optimizers['optimizer_DB'].clear_gradients()