提交 ec056e11 编写于 作者: L lijianshe02

refine code

上级 afd50f2e
......@@ -15,7 +15,7 @@
import argparse
import paddle
from ppgan.first_order_predictor import FirstOrderPredictor
from ppgan.apps.first_order_predictor import FirstOrderPredictor
parser = argparse.ArgumentParser()
parser.add_argument("--config", default=None, help="path to config")
......
......@@ -26,7 +26,7 @@ class BaseModel(ABC):
When creating your custom class, you need to implement your own initialization.
In this function, you should first call <BaseModel.__init__(self, opt)>
Then, you need to define four lists:
-- self.loss (str list): specify the training losses that you want to plot and save.
-- self.losses (str list): specify the training losses that you want to plot and save.
-- self.model_names (str list): define networks used in our training.
-- self.visual_names (str list): specify the images that you want to display and save.
-- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
......@@ -37,7 +37,7 @@ class BaseModel(ABC):
opt.output_dir,
opt.model.name) # save all the checkpoints to save_dir
self.loss = OrderedDict()
self.losses = OrderedDict()
self.model_names = []
self.visual_names = []
self.optimizers = []
......@@ -115,7 +115,7 @@ class BaseModel(ABC):
def get_current_losses(self):
"""Return traning losses / errors. train.py will print out these errors on console, and save them to a file"""
return self.loss
return self.losses
def set_requires_grad(self, nets, requires_grad=False):
"""Set requies_grad=Fasle for all the networks to avoid unnecessary computations
......
......@@ -161,13 +161,13 @@ class CycleGANModel(BaseModel):
"""Calculate GAN loss for discriminator D_A"""
fake_B = self.fake_B_pool.query(self.fake_B)
self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)
self.loss['D_A_loss'] = self.loss_D_A
self.losses['D_A_loss'] = self.loss_D_A
def backward_D_B(self):
"""Calculate GAN loss for discriminator D_B"""
fake_A = self.fake_A_pool.query(self.fake_A)
self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)
self.loss['D_B_loss'] = self.loss_D_B
self.losses['D_B_loss'] = self.loss_D_B
def backward_G(self):
"""Calculate the loss for generators G_A and G_B"""
......@@ -199,12 +199,12 @@ class CycleGANModel(BaseModel):
self.loss_cycle_B = self.criterionCycle(self.rec_B,
self.real_B) * lambda_B
self.loss['G_idt_A_loss'] = self.loss_idt_A
self.loss['G_idt_B_loss'] = self.loss_idt_B
self.loss['G_A_adv_loss'] = self.loss_G_A
self.loss['G_B_adv_loss'] = self.loss_G_B
self.loss['G_A_cycle_loss'] = self.loss_cycle_A
self.loss['G_B_cycle_loss'] = self.loss_cycle_B
self.losses['G_idt_A_loss'] = self.loss_idt_A
self.losses['G_idt_B_loss'] = self.loss_idt_B
self.losses['G_A_adv_loss'] = self.loss_G_A
self.losses['G_B_adv_loss'] = self.loss_G_B
self.losses['G_A_cycle_loss'] = self.loss_cycle_A
self.losses['G_B_cycle_loss'] = self.loss_cycle_B
# combined loss and calculate gradients
self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
......
......@@ -196,13 +196,13 @@ class MakeupModel(BaseModel):
"""Calculate GAN loss for discriminator D_A"""
fake_B = self.fake_B_pool.query(self.fake_B)
self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)
self.loss['D_A_loss'] = self.loss_D_A
self.losses['D_A_loss'] = self.loss_D_A
def backward_D_B(self):
"""Calculate GAN loss for discriminator D_B"""
fake_A = self.fake_A_pool.query(self.fake_A)
self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)
self.loss['D_B_loss'] = self.loss_D_B
self.losses['D_B_loss'] = self.loss_D_B
def backward_G(self):
"""Calculate the loss for generators G_A and G_B"""
......@@ -247,8 +247,8 @@ class MakeupModel(BaseModel):
self.loss_cycle_B = self.criterionCycle(self.rec_B,
self.real_B) * lambda_B
self.loss['G_A_adv_loss'] = self.loss_G_A
self.loss['G_B_adv_loss'] = self.loss_G_B
self.losses['G_A_adv_loss'] = self.loss_G_A
self.losses['G_B_adv_loss'] = self.loss_G_B
mask_A_lip = self.mask_A_aug[:, 0].unsqueeze(1)
mask_B_lip = self.mask_B_aug[:, 0].unsqueeze(1)
......@@ -336,8 +336,8 @@ class MakeupModel(BaseModel):
self.loss_G_B_his = (g_B_eye_loss_his + g_B_lip_loss_his +
g_B_skin_loss_his * 0.1) * 0.01
self.loss['G_A_his_loss'] = self.loss_G_A_his
self.loss['G_B_his_loss'] = self.loss_G_A_his
self.losses['G_A_his_loss'] = self.loss_G_A_his
self.losses['G_B_his_loss'] = self.loss_G_A_his
#vgg loss
vgg_s = self.vgg(self.real_A)
......@@ -356,10 +356,10 @@ class MakeupModel(BaseModel):
self.loss_A_vgg + self.loss_B_vgg) * 0.2
self.loss_idt = (self.loss_idt_A + self.loss_idt_B) * 0.2
self.loss['G_A_vgg_loss'] = self.loss_A_vgg
self.loss['G_B_vgg_loss'] = self.loss_B_vgg
self.loss['G_rec_loss'] = self.loss_rec
self.loss['G_idt_loss'] = self.loss_idt
self.losses['G_A_vgg_loss'] = self.loss_A_vgg
self.losses['G_B_vgg_loss'] = self.loss_B_vgg
self.losses['G_rec_loss'] = self.loss_rec
self.losses['G_idt_loss'] = self.loss_idt
# bg consistency loss
mask_A_consis = paddle.cast(
......
......@@ -113,8 +113,8 @@ class Pix2PixModel(BaseModel):
else:
self.loss_D.backward()
self.loss['D_fake_loss'] = self.loss_D_fake
self.loss['D_real_loss'] = self.loss_D_real
self.losses['D_fake_loss'] = self.loss_D_fake
self.losses['D_real_loss'] = self.loss_D_real
def backward_G(self):
"""Calculate GAN and L1 loss for the generator"""
......@@ -136,8 +136,8 @@ class Pix2PixModel(BaseModel):
else:
self.loss_G.backward()
self.loss['G_adv_loss'] = self.loss_G_GAN
self.loss['G_L1_loss'] = self.loss_G_L1
self.losses['G_adv_loss'] = self.loss_G_GAN
self.losses['G_L1_loss'] = self.loss_G_L1
def optimize_parameters(self):
# compute fake images: G(A)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册