未验证 提交 6c75d650 编写于 作者: L lijianshe02 提交者: GitHub

Merge pull request #39 from lijianshe02/master

refine loss related code and align to the newest develop api
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
import argparse import argparse
import paddle import paddle
from ppgan.first_order_predictor import FirstOrderPredictor from ppgan.apps.first_order_predictor import FirstOrderPredictor
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--config", default=None, help="path to config") parser.add_argument("--config", default=None, help="path to config")
......
...@@ -26,7 +26,7 @@ class BaseModel(ABC): ...@@ -26,7 +26,7 @@ class BaseModel(ABC):
When creating your custom class, you need to implement your own initialization. When creating your custom class, you need to implement your own initialization.
In this function, you should first call <BaseModel.__init__(self, opt)> In this function, you should first call <BaseModel.__init__(self, opt)>
Then, you need to define four lists: Then, you need to define four lists:
-- self.loss_names (str list): specify the training losses that you want to plot and save. -- self.losses (str list): specify the training losses that you want to plot and save.
-- self.model_names (str list): define networks used in our training. -- self.model_names (str list): define networks used in our training.
-- self.visual_names (str list): specify the images that you want to display and save. -- self.visual_names (str list): specify the images that you want to display and save.
-- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example. -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
...@@ -37,7 +37,7 @@ class BaseModel(ABC): ...@@ -37,7 +37,7 @@ class BaseModel(ABC):
opt.output_dir, opt.output_dir,
opt.model.name) # save all the checkpoints to save_dir opt.model.name) # save all the checkpoints to save_dir
self.loss_names = [] self.losses = OrderedDict()
self.model_names = [] self.model_names = []
self.visual_names = [] self.visual_names = []
self.optimizers = [] self.optimizers = []
...@@ -115,13 +115,7 @@ class BaseModel(ABC): ...@@ -115,13 +115,7 @@ class BaseModel(ABC):
def get_current_losses(self): def get_current_losses(self):
"""Return traning losses / errors. train.py will print out these errors on console, and save them to a file""" """Return traning losses / errors. train.py will print out these errors on console, and save them to a file"""
errors_ret = OrderedDict() return self.losses
for name in self.loss_names:
if isinstance(name, str):
errors_ret[name] = float(
getattr(self, 'loss_' + name)
) # float(...) works for both scalar tensor and float number
return errors_ret
def set_requires_grad(self, nets, requires_grad=False): def set_requires_grad(self, nets, requires_grad=False):
"""Set requies_grad=Fasle for all the networks to avoid unnecessary computations """Set requies_grad=Fasle for all the networks to avoid unnecessary computations
......
...@@ -31,10 +31,6 @@ class CycleGANModel(BaseModel): ...@@ -31,10 +31,6 @@ class CycleGANModel(BaseModel):
opt (config)-- stores all the experiment flags; needs to be a subclass of Dict opt (config)-- stores all the experiment flags; needs to be a subclass of Dict
""" """
BaseModel.__init__(self, opt) BaseModel.__init__(self, opt)
# specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
self.loss_names = [
'D_A', 'G_A', 'cycle_A', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B'
]
# specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals> # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
visual_names_A = ['real_A', 'fake_B', 'rec_A'] visual_names_A = ['real_A', 'fake_B', 'rec_A']
visual_names_B = ['real_B', 'fake_A', 'rec_B'] visual_names_B = ['real_B', 'fake_A', 'rec_B']
...@@ -165,11 +161,13 @@ class CycleGANModel(BaseModel): ...@@ -165,11 +161,13 @@ class CycleGANModel(BaseModel):
"""Calculate GAN loss for discriminator D_A""" """Calculate GAN loss for discriminator D_A"""
fake_B = self.fake_B_pool.query(self.fake_B) fake_B = self.fake_B_pool.query(self.fake_B)
self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B) self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)
self.losses['D_A_loss'] = self.loss_D_A
def backward_D_B(self): def backward_D_B(self):
"""Calculate GAN loss for discriminator D_B""" """Calculate GAN loss for discriminator D_B"""
fake_A = self.fake_A_pool.query(self.fake_A) fake_A = self.fake_A_pool.query(self.fake_A)
self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A) self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)
self.losses['D_B_loss'] = self.loss_D_B
def backward_G(self): def backward_G(self):
"""Calculate the loss for generators G_A and G_B""" """Calculate the loss for generators G_A and G_B"""
...@@ -200,6 +198,13 @@ class CycleGANModel(BaseModel): ...@@ -200,6 +198,13 @@ class CycleGANModel(BaseModel):
# Backward cycle loss || G_A(G_B(B)) - B|| # Backward cycle loss || G_A(G_B(B)) - B||
self.loss_cycle_B = self.criterionCycle(self.rec_B, self.loss_cycle_B = self.criterionCycle(self.rec_B,
self.real_B) * lambda_B self.real_B) * lambda_B
self.losses['G_idt_A_loss'] = self.loss_idt_A
self.losses['G_idt_B_loss'] = self.loss_idt_B
self.losses['G_A_adv_loss'] = self.loss_G_A
self.losses['G_B_adv_loss'] = self.loss_G_B
self.losses['G_A_cycle_loss'] = self.loss_cycle_A
self.losses['G_B_cycle_loss'] = self.loss_cycle_B
# combined loss and calculate gradients # combined loss and calculate gradients
self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
......
...@@ -47,16 +47,16 @@ class GANLoss(nn.Layer): ...@@ -47,16 +47,16 @@ class GANLoss(nn.Layer):
""" """
if target_is_real: if target_is_real:
if not hasattr(self, 'target_real_tensor'): if not hasattr(self, 'target_real_tensor'):
self.target_real_tensor = paddle.fill_constant( self.target_real_tensor = paddle.full(
shape=paddle.shape(prediction), shape=paddle.shape(prediction),
value=self.target_real_label, fill_value=self.target_real_label,
dtype='float32') dtype='float32')
target_tensor = self.target_real_tensor target_tensor = self.target_real_tensor
else: else:
if not hasattr(self, 'target_fake_tensor'): if not hasattr(self, 'target_fake_tensor'):
self.target_fake_tensor = paddle.fill_constant( self.target_fake_tensor = paddle.full(
shape=paddle.shape(prediction), shape=paddle.shape(prediction),
value=self.target_fake_label, fill_value=self.target_fake_label,
dtype='float32') dtype='float32')
target_tensor = self.target_fake_tensor target_tensor = self.target_fake_tensor
......
...@@ -50,19 +50,6 @@ class MakeupModel(BaseModel): ...@@ -50,19 +50,6 @@ class MakeupModel(BaseModel):
""" """
BaseModel.__init__(self, opt) BaseModel.__init__(self, opt)
# specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses> # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
self.loss_names = [
'D_A',
'G_A',
'rec',
'idt',
'D_B',
'G_B',
'G_A_his',
'G_B_his',
'G_bg_consis',
'A_vgg',
'B_vgg',
]
# specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals> # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
visual_names_A = ['real_A', 'fake_A', 'rec_A'] visual_names_A = ['real_A', 'fake_A', 'rec_A']
visual_names_B = ['real_B', 'fake_B', 'rec_B'] visual_names_B = ['real_B', 'fake_B', 'rec_B']
...@@ -209,11 +196,13 @@ class MakeupModel(BaseModel): ...@@ -209,11 +196,13 @@ class MakeupModel(BaseModel):
"""Calculate GAN loss for discriminator D_A""" """Calculate GAN loss for discriminator D_A"""
fake_B = self.fake_B_pool.query(self.fake_B) fake_B = self.fake_B_pool.query(self.fake_B)
self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B) self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)
self.losses['D_A_loss'] = self.loss_D_A
def backward_D_B(self): def backward_D_B(self):
"""Calculate GAN loss for discriminator D_B""" """Calculate GAN loss for discriminator D_B"""
fake_A = self.fake_A_pool.query(self.fake_A) fake_A = self.fake_A_pool.query(self.fake_A)
self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A) self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)
self.losses['D_B_loss'] = self.loss_D_B
def backward_G(self): def backward_G(self):
"""Calculate the loss for generators G_A and G_B""" """Calculate the loss for generators G_A and G_B"""
...@@ -258,6 +247,9 @@ class MakeupModel(BaseModel): ...@@ -258,6 +247,9 @@ class MakeupModel(BaseModel):
self.loss_cycle_B = self.criterionCycle(self.rec_B, self.loss_cycle_B = self.criterionCycle(self.rec_B,
self.real_B) * lambda_B self.real_B) * lambda_B
self.losses['G_A_adv_loss'] = self.loss_G_A
self.losses['G_B_adv_loss'] = self.loss_G_B
mask_A_lip = self.mask_A_aug[:, 0].unsqueeze(1) mask_A_lip = self.mask_A_aug[:, 0].unsqueeze(1)
mask_B_lip = self.mask_B_aug[:, 0].unsqueeze(1) mask_B_lip = self.mask_B_aug[:, 0].unsqueeze(1)
...@@ -344,10 +336,8 @@ class MakeupModel(BaseModel): ...@@ -344,10 +336,8 @@ class MakeupModel(BaseModel):
self.loss_G_B_his = (g_B_eye_loss_his + g_B_lip_loss_his + self.loss_G_B_his = (g_B_eye_loss_his + g_B_lip_loss_his +
g_B_skin_loss_his * 0.1) * 0.01 g_B_skin_loss_his * 0.1) * 0.01
#self.loss_G_A_his = self.criterionL1(tmp_1, tmp_2) * 2048 * 255 self.losses['G_A_his_loss'] = self.loss_G_A_his
#tmp_3 = self.hm_gt_B*self.hm_mask_weight_B self.losses['G_B_his_loss'] = self.loss_G_A_his
#tmp_4 = self.fake_B*self.hm_mask_weight_B
#self.loss_G_B_his = self.criterionL1(tmp_3, tmp_4) * 2048 * 255
#vgg loss #vgg loss
vgg_s = self.vgg(self.real_A) vgg_s = self.vgg(self.real_A)
...@@ -366,6 +356,11 @@ class MakeupModel(BaseModel): ...@@ -366,6 +356,11 @@ class MakeupModel(BaseModel):
self.loss_A_vgg + self.loss_B_vgg) * 0.2 self.loss_A_vgg + self.loss_B_vgg) * 0.2
self.loss_idt = (self.loss_idt_A + self.loss_idt_B) * 0.2 self.loss_idt = (self.loss_idt_A + self.loss_idt_B) * 0.2
self.losses['G_A_vgg_loss'] = self.loss_A_vgg
self.losses['G_B_vgg_loss'] = self.loss_B_vgg
self.losses['G_rec_loss'] = self.loss_rec
self.losses['G_idt_loss'] = self.loss_idt
# bg consistency loss # bg consistency loss
mask_A_consis = paddle.cast( mask_A_consis = paddle.cast(
(self.mask_A == 0), dtype='float32') + paddle.cast( (self.mask_A == 0), dtype='float32') + paddle.cast(
......
...@@ -31,7 +31,6 @@ class Pix2PixModel(BaseModel): ...@@ -31,7 +31,6 @@ class Pix2PixModel(BaseModel):
""" """
BaseModel.__init__(self, opt) BaseModel.__init__(self, opt)
# specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses> # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
self.loss_names = ['G_GAN', 'G_L1', 'D_real', 'D_fake']
# specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals> # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
self.visual_names = ['real_A', 'fake_B', 'real_B'] self.visual_names = ['real_A', 'fake_B', 'real_B']
# specify the models you want to save to the disk. # specify the models you want to save to the disk.
...@@ -114,6 +113,9 @@ class Pix2PixModel(BaseModel): ...@@ -114,6 +113,9 @@ class Pix2PixModel(BaseModel):
else: else:
self.loss_D.backward() self.loss_D.backward()
self.losses['D_fake_loss'] = self.loss_D_fake
self.losses['D_real_loss'] = self.loss_D_real
def backward_G(self): def backward_G(self):
"""Calculate GAN and L1 loss for the generator""" """Calculate GAN and L1 loss for the generator"""
# First, G(A) should fake the discriminator # First, G(A) should fake the discriminator
...@@ -134,6 +136,9 @@ class Pix2PixModel(BaseModel): ...@@ -134,6 +136,9 @@ class Pix2PixModel(BaseModel):
else: else:
self.loss_G.backward() self.loss_G.backward()
self.losses['G_adv_loss'] = self.loss_G_GAN
self.losses['G_L1_loss'] = self.loss_G_L1
def optimize_parameters(self): def optimize_parameters(self):
# compute fake images: G(A) # compute fake images: G(A)
self.forward() self.forward()
......
...@@ -45,7 +45,7 @@ def vgg16(pretrained=False): ...@@ -45,7 +45,7 @@ def vgg16(pretrained=False):
if pretrained: if pretrained:
weight_path = get_weights_path_from_url(model_urls['vgg16'][0], weight_path = get_weights_path_from_url(model_urls['vgg16'][0],
model_urls['vgg16'][1]) model_urls['vgg16'][1])
param, _ = paddle.load(weight_path) param = paddle.load(weight_path)
model.load_dict(param) model.load_dict(param)
return model return model
...@@ -80,7 +80,7 @@ def calculate_gain(nonlinearity, param=None): ...@@ -80,7 +80,7 @@ def calculate_gain(nonlinearity, param=None):
@paddle.no_grad() @paddle.no_grad()
def constant_(x, value): def constant_(x, value):
temp_value = paddle.fill_constant(x.shape, x.dtype, value) temp_value = paddle.full(x.shape, value, x.dtype)
x.set_value(temp_value) x.set_value(temp_value)
return x return x
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册