提交 ec056e11 编写于 作者: L lijianshe02

refine code

上级 afd50f2e
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
import argparse import argparse
import paddle import paddle
from ppgan.first_order_predictor import FirstOrderPredictor from ppgan.apps.first_order_predictor import FirstOrderPredictor
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--config", default=None, help="path to config") parser.add_argument("--config", default=None, help="path to config")
......
...@@ -26,7 +26,7 @@ class BaseModel(ABC): ...@@ -26,7 +26,7 @@ class BaseModel(ABC):
When creating your custom class, you need to implement your own initialization. When creating your custom class, you need to implement your own initialization.
In this function, you should first call <BaseModel.__init__(self, opt)> In this function, you should first call <BaseModel.__init__(self, opt)>
Then, you need to define four lists: Then, you need to define four lists:
-- self.loss (str list): specify the training losses that you want to plot and save. -- self.losses (str list): specify the training losses that you want to plot and save.
-- self.model_names (str list): define networks used in our training. -- self.model_names (str list): define networks used in our training.
-- self.visual_names (str list): specify the images that you want to display and save. -- self.visual_names (str list): specify the images that you want to display and save.
-- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example. -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
...@@ -37,7 +37,7 @@ class BaseModel(ABC): ...@@ -37,7 +37,7 @@ class BaseModel(ABC):
opt.output_dir, opt.output_dir,
opt.model.name) # save all the checkpoints to save_dir opt.model.name) # save all the checkpoints to save_dir
self.loss = OrderedDict() self.losses = OrderedDict()
self.model_names = [] self.model_names = []
self.visual_names = [] self.visual_names = []
self.optimizers = [] self.optimizers = []
...@@ -115,7 +115,7 @@ class BaseModel(ABC): ...@@ -115,7 +115,7 @@ class BaseModel(ABC):
def get_current_losses(self): def get_current_losses(self):
"""Return traning losses / errors. train.py will print out these errors on console, and save them to a file""" """Return traning losses / errors. train.py will print out these errors on console, and save them to a file"""
return self.loss return self.losses
def set_requires_grad(self, nets, requires_grad=False): def set_requires_grad(self, nets, requires_grad=False):
"""Set requies_grad=Fasle for all the networks to avoid unnecessary computations """Set requies_grad=Fasle for all the networks to avoid unnecessary computations
......
...@@ -161,13 +161,13 @@ class CycleGANModel(BaseModel): ...@@ -161,13 +161,13 @@ class CycleGANModel(BaseModel):
"""Calculate GAN loss for discriminator D_A""" """Calculate GAN loss for discriminator D_A"""
fake_B = self.fake_B_pool.query(self.fake_B) fake_B = self.fake_B_pool.query(self.fake_B)
self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B) self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)
self.loss['D_A_loss'] = self.loss_D_A self.losses['D_A_loss'] = self.loss_D_A
def backward_D_B(self): def backward_D_B(self):
"""Calculate GAN loss for discriminator D_B""" """Calculate GAN loss for discriminator D_B"""
fake_A = self.fake_A_pool.query(self.fake_A) fake_A = self.fake_A_pool.query(self.fake_A)
self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A) self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)
self.loss['D_B_loss'] = self.loss_D_B self.losses['D_B_loss'] = self.loss_D_B
def backward_G(self): def backward_G(self):
"""Calculate the loss for generators G_A and G_B""" """Calculate the loss for generators G_A and G_B"""
...@@ -199,12 +199,12 @@ class CycleGANModel(BaseModel): ...@@ -199,12 +199,12 @@ class CycleGANModel(BaseModel):
self.loss_cycle_B = self.criterionCycle(self.rec_B, self.loss_cycle_B = self.criterionCycle(self.rec_B,
self.real_B) * lambda_B self.real_B) * lambda_B
self.loss['G_idt_A_loss'] = self.loss_idt_A self.losses['G_idt_A_loss'] = self.loss_idt_A
self.loss['G_idt_B_loss'] = self.loss_idt_B self.losses['G_idt_B_loss'] = self.loss_idt_B
self.loss['G_A_adv_loss'] = self.loss_G_A self.losses['G_A_adv_loss'] = self.loss_G_A
self.loss['G_B_adv_loss'] = self.loss_G_B self.losses['G_B_adv_loss'] = self.loss_G_B
self.loss['G_A_cycle_loss'] = self.loss_cycle_A self.losses['G_A_cycle_loss'] = self.loss_cycle_A
self.loss['G_B_cycle_loss'] = self.loss_cycle_B self.losses['G_B_cycle_loss'] = self.loss_cycle_B
# combined loss and calculate gradients # combined loss and calculate gradients
self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
......
...@@ -196,13 +196,13 @@ class MakeupModel(BaseModel): ...@@ -196,13 +196,13 @@ class MakeupModel(BaseModel):
"""Calculate GAN loss for discriminator D_A""" """Calculate GAN loss for discriminator D_A"""
fake_B = self.fake_B_pool.query(self.fake_B) fake_B = self.fake_B_pool.query(self.fake_B)
self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B) self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)
self.loss['D_A_loss'] = self.loss_D_A self.losses['D_A_loss'] = self.loss_D_A
def backward_D_B(self): def backward_D_B(self):
"""Calculate GAN loss for discriminator D_B""" """Calculate GAN loss for discriminator D_B"""
fake_A = self.fake_A_pool.query(self.fake_A) fake_A = self.fake_A_pool.query(self.fake_A)
self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A) self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)
self.loss['D_B_loss'] = self.loss_D_B self.losses['D_B_loss'] = self.loss_D_B
def backward_G(self): def backward_G(self):
"""Calculate the loss for generators G_A and G_B""" """Calculate the loss for generators G_A and G_B"""
...@@ -247,8 +247,8 @@ class MakeupModel(BaseModel): ...@@ -247,8 +247,8 @@ class MakeupModel(BaseModel):
self.loss_cycle_B = self.criterionCycle(self.rec_B, self.loss_cycle_B = self.criterionCycle(self.rec_B,
self.real_B) * lambda_B self.real_B) * lambda_B
self.loss['G_A_adv_loss'] = self.loss_G_A self.losses['G_A_adv_loss'] = self.loss_G_A
self.loss['G_B_adv_loss'] = self.loss_G_B self.losses['G_B_adv_loss'] = self.loss_G_B
mask_A_lip = self.mask_A_aug[:, 0].unsqueeze(1) mask_A_lip = self.mask_A_aug[:, 0].unsqueeze(1)
mask_B_lip = self.mask_B_aug[:, 0].unsqueeze(1) mask_B_lip = self.mask_B_aug[:, 0].unsqueeze(1)
...@@ -336,8 +336,8 @@ class MakeupModel(BaseModel): ...@@ -336,8 +336,8 @@ class MakeupModel(BaseModel):
self.loss_G_B_his = (g_B_eye_loss_his + g_B_lip_loss_his + self.loss_G_B_his = (g_B_eye_loss_his + g_B_lip_loss_his +
g_B_skin_loss_his * 0.1) * 0.01 g_B_skin_loss_his * 0.1) * 0.01
self.loss['G_A_his_loss'] = self.loss_G_A_his self.losses['G_A_his_loss'] = self.loss_G_A_his
self.loss['G_B_his_loss'] = self.loss_G_A_his self.losses['G_B_his_loss'] = self.loss_G_A_his
#vgg loss #vgg loss
vgg_s = self.vgg(self.real_A) vgg_s = self.vgg(self.real_A)
...@@ -356,10 +356,10 @@ class MakeupModel(BaseModel): ...@@ -356,10 +356,10 @@ class MakeupModel(BaseModel):
self.loss_A_vgg + self.loss_B_vgg) * 0.2 self.loss_A_vgg + self.loss_B_vgg) * 0.2
self.loss_idt = (self.loss_idt_A + self.loss_idt_B) * 0.2 self.loss_idt = (self.loss_idt_A + self.loss_idt_B) * 0.2
self.loss['G_A_vgg_loss'] = self.loss_A_vgg self.losses['G_A_vgg_loss'] = self.loss_A_vgg
self.loss['G_B_vgg_loss'] = self.loss_B_vgg self.losses['G_B_vgg_loss'] = self.loss_B_vgg
self.loss['G_rec_loss'] = self.loss_rec self.losses['G_rec_loss'] = self.loss_rec
self.loss['G_idt_loss'] = self.loss_idt self.losses['G_idt_loss'] = self.loss_idt
# bg consistency loss # bg consistency loss
mask_A_consis = paddle.cast( mask_A_consis = paddle.cast(
......
...@@ -113,8 +113,8 @@ class Pix2PixModel(BaseModel): ...@@ -113,8 +113,8 @@ class Pix2PixModel(BaseModel):
else: else:
self.loss_D.backward() self.loss_D.backward()
self.loss['D_fake_loss'] = self.loss_D_fake self.losses['D_fake_loss'] = self.loss_D_fake
self.loss['D_real_loss'] = self.loss_D_real self.losses['D_real_loss'] = self.loss_D_real
def backward_G(self): def backward_G(self):
"""Calculate GAN and L1 loss for the generator""" """Calculate GAN and L1 loss for the generator"""
...@@ -136,8 +136,8 @@ class Pix2PixModel(BaseModel): ...@@ -136,8 +136,8 @@ class Pix2PixModel(BaseModel):
else: else:
self.loss_G.backward() self.loss_G.backward()
self.loss['G_adv_loss'] = self.loss_G_GAN self.losses['G_adv_loss'] = self.loss_G_GAN
self.loss['G_L1_loss'] = self.loss_G_L1 self.losses['G_L1_loss'] = self.loss_G_L1
def optimize_parameters(self): def optimize_parameters(self):
# compute fake images: G(A) # compute fake images: G(A)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册