Unverified commit 6c75d650, authored by lijianshe02, committed by GitHub

Merge pull request #39 from lijianshe02/master

Refine loss-related code and align with the newest develop API
......@@ -15,7 +15,7 @@
import argparse
import paddle
-from ppgan.first_order_predictor import FirstOrderPredictor
+from ppgan.apps.first_order_predictor import FirstOrderPredictor
parser = argparse.ArgumentParser()
parser.add_argument("--config", default=None, help="path to config")
......
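Reviewer note: `FirstOrderPredictor` moves under `ppgan.apps`, so any downstream script needs the new import path. A minimal usage sketch, assuming the constructor accepts the same `config` that the `--config` flag above feeds it and exposes a `run(source, driving)` entry point (file paths here are hypothetical):

```python
from ppgan.apps.first_order_predictor import FirstOrderPredictor

# Hypothetical inputs; point these at your own config, image, and video.
predictor = FirstOrderPredictor(config="configs/firstorder.yaml")
predictor.run("source_face.png", "driving_video.mp4")
```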
......@@ -26,7 +26,7 @@ class BaseModel(ABC):
When creating your custom class, you need to implement your own initialization.
In this function, you should first call <BaseModel.__init__(self, opt)>
Then, you need to define four attributes:
--- self.loss_names (str list): specify the training losses that you want to plot and save.
+-- self.losses (OrderedDict): specify the training losses that you want to plot and save.
-- self.model_names (str list): define networks used in our training.
-- self.visual_names (str list): specify the images that you want to display and save.
-- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
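For orientation, a minimal sketch of what a subclass's `__init__` looks like after this change (`ToyModel` and its field values are hypothetical; `self.losses` is the OrderedDict this PR introduces in place of `loss_names`):

```python
from collections import OrderedDict

class ToyModel(BaseModel):
    def __init__(self, opt):
        BaseModel.__init__(self, opt)
        self.losses = OrderedDict()            # training losses to plot and save
        self.model_names = ['G']               # networks used in training
        self.visual_names = ['real', 'fake']   # images to display and save
        self.optimizers = []                   # one optimizer per network, or itertools.chain them
```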
......@@ -37,7 +37,7 @@ class BaseModel(ABC):
opt.output_dir,
opt.model.name) # save all the checkpoints to save_dir
-self.loss_names = []
+self.losses = OrderedDict()
self.model_names = []
self.visual_names = []
self.optimizers = []
......@@ -115,13 +115,7 @@ class BaseModel(ABC):
def get_current_losses(self):
"""Return traning losses / errors. train.py will print out these errors on console, and save them to a file"""
errors_ret = OrderedDict()
for name in self.loss_names:
if isinstance(name, str):
errors_ret[name] = float(
getattr(self, 'loss_' + name)
) # float(...) works for both scalar tensor and float number
return errors_ret
return self.losses
def set_requires_grad(self, nets, requires_grad=False):
"""Set requies_grad=Fasle for all the networks to avoid unnecessary computations
......
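With this rewrite, `get_current_losses` simply hands back `self.losses` as populated by the `backward_*` methods. Note that the old `float(...)` conversion is gone, so the dict now holds loss tensors and the consumer converts. A sketch of how a training script might log them (the loop itself is illustrative, not code from this PR):

```python
losses = model.get_current_losses()  # OrderedDict filled in the backward_*() methods
message = ' '.join('%s: %.3f' % (name, float(value)) for name, value in losses.items())
print(message)  # e.g. "D_A_loss: 0.231 G_A_adv_loss: 1.204 ..."
```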
......@@ -31,10 +31,6 @@ class CycleGANModel(BaseModel):
opt (config)-- stores all the experiment flags; needs to be a subclass of Dict
"""
BaseModel.__init__(self, opt)
-# specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
-self.loss_names = [
-    'D_A', 'G_A', 'cycle_A', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B'
-]
# specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
visual_names_A = ['real_A', 'fake_B', 'rec_A']
visual_names_B = ['real_B', 'fake_A', 'rec_B']
......@@ -165,11 +161,13 @@ class CycleGANModel(BaseModel):
"""Calculate GAN loss for discriminator D_A"""
fake_B = self.fake_B_pool.query(self.fake_B)
self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)
+self.losses['D_A_loss'] = self.loss_D_A
def backward_D_B(self):
"""Calculate GAN loss for discriminator D_B"""
fake_A = self.fake_A_pool.query(self.fake_A)
self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)
+self.losses['D_B_loss'] = self.loss_D_B
def backward_G(self):
"""Calculate the loss for generators G_A and G_B"""
......@@ -200,6 +198,13 @@ class CycleGANModel(BaseModel):
# Backward cycle loss || G_A(G_B(B)) - B||
self.loss_cycle_B = self.criterionCycle(self.rec_B,
self.real_B) * lambda_B
+self.losses['G_idt_A_loss'] = self.loss_idt_A
+self.losses['G_idt_B_loss'] = self.loss_idt_B
+self.losses['G_A_adv_loss'] = self.loss_G_A
+self.losses['G_B_adv_loss'] = self.loss_G_B
+self.losses['G_A_cycle_loss'] = self.loss_cycle_A
+self.losses['G_B_cycle_loss'] = self.loss_cycle_B
# combined loss and calculate gradients
self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
......
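`self.fake_B_pool.query(...)` above is the classic CycleGAN image buffer: the discriminator is sometimes shown an older generated image instead of the newest one, which stabilizes training. A simplified sketch of the idea (not the repo's actual `ImagePool` implementation):

```python
import random

class SimpleImagePool:
    """Keep a history of generated images and sometimes return an old one."""
    def __init__(self, pool_size=50):
        self.pool_size = pool_size
        self.images = []

    def query(self, image):
        if len(self.images) < self.pool_size:  # pool not full: store and pass through
            self.images.append(image)
            return image
        if random.random() > 0.5:              # 50%: swap in the new image, return an old one
            idx = random.randrange(self.pool_size)
            old = self.images[idx]
            self.images[idx] = image
            return old
        return image                           # otherwise return the current image
```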
......@@ -47,16 +47,16 @@ class GANLoss(nn.Layer):
"""
if target_is_real:
if not hasattr(self, 'target_real_tensor'):
-self.target_real_tensor = paddle.fill_constant(
+self.target_real_tensor = paddle.full(
shape=paddle.shape(prediction),
-    value=self.target_real_label,
+    fill_value=self.target_real_label,
dtype='float32')
target_tensor = self.target_real_tensor
else:
if not hasattr(self, 'target_fake_tensor'):
-self.target_fake_tensor = paddle.fill_constant(
+self.target_fake_tensor = paddle.full(
shape=paddle.shape(prediction),
-    value=self.target_fake_label,
+    fill_value=self.target_fake_label,
dtype='float32')
target_tensor = self.target_fake_tensor
......
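The driver for this hunk is the API migration: `paddle.fill_constant` is replaced by `paddle.full`, whose keyword argument is `fill_value` rather than `value`. A standalone equivalence check, assuming a Paddle 2.x dygraph environment:

```python
import paddle

pred = paddle.rand([4, 1, 30, 30])  # stand-in for a discriminator prediction map
target = paddle.full(shape=paddle.shape(pred), fill_value=1.0, dtype='float32')
assert list(target.shape) == list(pred.shape)  # same-shape tensor of "real" labels
```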
......@@ -50,19 +50,6 @@ class MakeupModel(BaseModel):
"""
BaseModel.__init__(self, opt)
# specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
-self.loss_names = [
-    'D_A',
-    'G_A',
-    'rec',
-    'idt',
-    'D_B',
-    'G_B',
-    'G_A_his',
-    'G_B_his',
-    'G_bg_consis',
-    'A_vgg',
-    'B_vgg',
-]
# specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
visual_names_A = ['real_A', 'fake_A', 'rec_A']
visual_names_B = ['real_B', 'fake_B', 'rec_B']
......@@ -209,11 +196,13 @@ class MakeupModel(BaseModel):
"""Calculate GAN loss for discriminator D_A"""
fake_B = self.fake_B_pool.query(self.fake_B)
self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)
+self.losses['D_A_loss'] = self.loss_D_A
def backward_D_B(self):
"""Calculate GAN loss for discriminator D_B"""
fake_A = self.fake_A_pool.query(self.fake_A)
self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)
+self.losses['D_B_loss'] = self.loss_D_B
def backward_G(self):
"""Calculate the loss for generators G_A and G_B"""
......@@ -258,6 +247,9 @@ class MakeupModel(BaseModel):
self.loss_cycle_B = self.criterionCycle(self.rec_B,
self.real_B) * lambda_B
+self.losses['G_A_adv_loss'] = self.loss_G_A
+self.losses['G_B_adv_loss'] = self.loss_G_B
mask_A_lip = self.mask_A_aug[:, 0].unsqueeze(1)
mask_B_lip = self.mask_B_aug[:, 0].unsqueeze(1)
......@@ -344,10 +336,8 @@ class MakeupModel(BaseModel):
self.loss_G_B_his = (g_B_eye_loss_his + g_B_lip_loss_his +
g_B_skin_loss_his * 0.1) * 0.01
-#self.loss_G_A_his = self.criterionL1(tmp_1, tmp_2) * 2048 * 255
-#tmp_3 = self.hm_gt_B*self.hm_mask_weight_B
-#tmp_4 = self.fake_B*self.hm_mask_weight_B
-#self.loss_G_B_his = self.criterionL1(tmp_3, tmp_4) * 2048 * 255
+self.losses['G_A_his_loss'] = self.loss_G_A_his
+self.losses['G_B_his_loss'] = self.loss_G_B_his
#vgg loss
vgg_s = self.vgg(self.real_A)
......@@ -366,6 +356,11 @@ class MakeupModel(BaseModel):
self.loss_A_vgg + self.loss_B_vgg) * 0.2
self.loss_idt = (self.loss_idt_A + self.loss_idt_B) * 0.2
+self.losses['G_A_vgg_loss'] = self.loss_A_vgg
+self.losses['G_B_vgg_loss'] = self.loss_B_vgg
+self.losses['G_rec_loss'] = self.loss_rec
+self.losses['G_idt_loss'] = self.loss_idt
# bg consistency loss
mask_A_consis = paddle.cast(
(self.mask_A == 0), dtype='float32') + paddle.cast(
......
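The `vgg_s = self.vgg(self.real_A)` call above feeds a VGG perceptual loss. For readers unfamiliar with the pattern, a generic sketch (layer selection and weighting are illustrative; this is not the repo's exact criterion):

```python
import paddle

def vgg_perceptual_loss(vgg, fake, real):
    """L2 distance between VGG feature maps of generated and reference images."""
    feat_fake = vgg(fake)
    feat_real = vgg(real)
    feat_real.stop_gradient = True  # do not backprop through the reference branch
    return paddle.mean((feat_fake - feat_real) ** 2)
```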
......@@ -31,7 +31,6 @@ class Pix2PixModel(BaseModel):
"""
BaseModel.__init__(self, opt)
# specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
-self.loss_names = ['G_GAN', 'G_L1', 'D_real', 'D_fake']
# specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
self.visual_names = ['real_A', 'fake_B', 'real_B']
# specify the models you want to save to the disk.
......@@ -114,6 +113,9 @@ class Pix2PixModel(BaseModel):
else:
self.loss_D.backward()
+self.losses['D_fake_loss'] = self.loss_D_fake
+self.losses['D_real_loss'] = self.loss_D_real
def backward_G(self):
"""Calculate GAN and L1 loss for the generator"""
# First, G(A) should fake the discriminator
......@@ -134,6 +136,9 @@ class Pix2PixModel(BaseModel):
else:
self.loss_G.backward()
+self.losses['G_adv_loss'] = self.loss_G_GAN
+self.losses['G_L1_loss'] = self.loss_G_L1
def optimize_parameters(self):
# compute fake images: G(A)
self.forward()
......
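`optimize_parameters` is truncated here after `self.forward()`; for context, the usual pix2pix continuation alternates a D step and a G step using `set_requires_grad` from base_model above. A sketch under the assumptions that the model keeps `optimizer_D`/`optimizer_G` attributes and that `backward_D`/`backward_G` already call `.backward()` on their losses, as the hunks above show:

```python
def optimize_parameters(self):
    self.forward()                             # compute fake images: G(A)
    # update D
    self.set_requires_grad(self.netD, True)    # enable backprop for D
    self.backward_D()                          # records D_fake_loss / D_real_loss
    self.optimizer_D.minimize(self.loss_D)     # apply D gradients
    self.optimizer_D.clear_gradients()
    # update G (with D frozen)
    self.set_requires_grad(self.netD, False)   # D needs no gradients on the G step
    self.backward_G()                          # records G_adv_loss / G_L1_loss
    self.optimizer_G.minimize(self.loss_G)     # apply G gradients
    self.optimizer_G.clear_gradients()
```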
......@@ -45,7 +45,7 @@ def vgg16(pretrained=False):
if pretrained:
weight_path = get_weights_path_from_url(model_urls['vgg16'][0],
model_urls['vgg16'][1])
-param, _ = paddle.load(weight_path)
+param = paddle.load(weight_path)
model.load_dict(param)
return model
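The dropped unpacking reflects the updated `paddle.load`: it used to return a `(parameters, optimizer_state)` tuple and now returns the parameter dict alone, which `load_dict` consumes directly. Usage is unchanged from the caller's side:

```python
# Downloads the weights if needed, then restores them via model.load_dict(param).
model = vgg16(pretrained=True)
```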
......@@ -80,7 +80,7 @@ def calculate_gain(nonlinearity, param=None):
@paddle.no_grad()
def constant_(x, value):
-temp_value = paddle.fill_constant(x.shape, x.dtype, value)
+temp_value = paddle.full(x.shape, value, x.dtype)
x.set_value(temp_value)
return x
......
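Same `fill_constant` → `full` migration as in GANLoss, here inside the in-place `constant_` initializer; note that `paddle.full` takes `(shape, fill_value, dtype)` positionally, so the argument order changes too. A quick usage sketch with a hypothetical layer:

```python
import paddle.nn as nn

linear = nn.Linear(4, 2)
constant_(linear.bias, 0.0)   # fills the bias in place via paddle.full + set_value
print(linear.bias.numpy())    # -> [0. 0.]
```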