From ea8f57633b8e9a2e0cf7f5ba036ad7456fdc6fc5 Mon Sep 17 00:00:00 2001 From: "Eric.Lee2021" <305141918@qq.com> Date: Sat, 17 Apr 2021 00:19:13 +0800 Subject: [PATCH] add adaptive_wing_loss --- loss/loss.py | 41 +++++++++++++++++++++++++++++++++++++++-- train.py | 12 +++++++----- 2 files changed, 46 insertions(+), 7 deletions(-) diff --git a/loss/loss.py b/loss/loss.py index 4a6b20c..620db1d 100644 --- a/loss/loss.py +++ b/loss/loss.py @@ -1,13 +1,17 @@ #-*-coding:utf-8-*- # date:2019-05-20 -# function: wing loss +# function: wing loss & adaptive wing loss import torch import torch.nn as nn import torch.optim as optim import os import math -def wing_loss(landmarks, labels, w=0.06, epsilon=0.01): +''' +wing_loss +''' +# def wing_loss(landmarks, labels, w=0.06, epsilon=0.01): +def wing_loss(landmarks, labels, w=10., epsilon=2.): """ Arguments: landmarks, labels: float tensors with shape [batch_size, landmarks]. landmarks means x1,x2,x3,x4...y1,y2,y3,y4 1-D @@ -34,3 +38,36 @@ def wing_loss(landmarks, labels, w=0.06, epsilon=0.01): def got_total_wing_loss(output,crop_landmarks): loss = wing_loss(output, crop_landmarks) return loss + +''' +AdaptiveWingLoss +''' + +class AdaptiveWingLoss(nn.Module): + def __init__(self, omega=14, theta=0.5, epsilon=1, alpha=2.1): + super(AdaptiveWingLoss, self).__init__() + self.omega = omega + self.theta = theta + self.epsilon = epsilon + self.alpha = alpha + + def forward(self, pred, target): + ''' + :param pred: BxNxHxH + :param target: BxNxHxH + :return: + ''' + + y = target + y_hat = pred + delta_y = (y - y_hat).abs() + delta_y1 = delta_y[delta_y < self.theta] + delta_y2 = delta_y[delta_y >= self.theta] + y1 = y[delta_y < self.theta] + y2 = y[delta_y >= self.theta] + loss1 = self.omega * torch.log(1 + torch.pow(delta_y1 / self.omega, self.alpha - y1)) + A = self.omega * (1 / (1 + torch.pow(self.theta / self.epsilon, self.alpha - y2))) * (self.alpha - y2) * ( + torch.pow(self.theta / self.epsilon, self.alpha - y2 - 1)) * (1 / self.epsilon) + C = self.theta * A - self.omega * torch.log(1 + torch.pow(self.theta / self.epsilon, self.alpha - y2)) + loss2 = A * delta_y2 - C + return (loss1.sum() + loss2.sum()) / (len(loss1) + len(loss2)) diff --git a/train.py b/train.py index 3c0fc25..f8bc737 100644 --- a/train.py +++ b/train.py @@ -99,8 +99,10 @@ def trainer(ops,f_log): print('/**********************************************/') # 损失函数 - if ops.loss_define != 'wing_loss': + if ops.loss_define == 'mse_loss': criterion = nn.MSELoss(reduce=True, reduction='mean') + elif ops.loss_define == 'adaptive_wing_loss': + criterion = AdaptiveWingLoss() step = 0 idx = 0 @@ -153,7 +155,7 @@ def trainer(ops,f_log): print(' %s - %s - epoch [%s/%s] (%s/%s):'%(loc_time,ops.model,epoch,ops.epochs,i,int(dataset.__len__()/ops.batch_size)),\ 'Mean Loss : %.6f - Loss: %.6f'%(loss_mean/loss_idx,loss.item()),\ ' lr : %.8f'%init_lr,' bs :',ops.batch_size,\ - ' img_size: %s x %s'%(ops.img_size[0],ops.img_size[1]),' best_loss: %.6f'%best_loss) + ' img_size: %s x %s'%(ops.img_size[0],ops.img_size[1]),' best_loss: %.6f'%best_loss, " {}".format(ops.loss_define)) # 计算梯度 loss.backward() # 优化器对模型参数更新 @@ -162,7 +164,7 @@ def trainer(ops,f_log): optimizer.zero_grad() step += 1 - torch.save(model_.state_dict(), ops.model_exp + '{}-size-{}-model_epoch-{}.pth'.format(ops.model,ops.img_size[0],epoch)) + torch.save(model_.state_dict(), ops.model_exp + '{}-size-{}-loss-{}-model_epoch-{}.pth'.format(ops.model,ops.img_size[0],ops.loss_define,epoch)) except Exception as e: print('Exception : ',e) # 打印异常 @@ -192,8 +194,8 @@ if __name__ == "__main__": help = 'imageNet_Pretrain') # 初始化学习率 parser.add_argument('--fintune_model', type=str, default = 'None', help = 'fintune_model') # fintune model - parser.add_argument('--loss_define', type=str, default = 'wing_loss', - help = 'define_loss') # 损失函数定义 + parser.add_argument('--loss_define', type=str, default = 'mse_loss', + help = 'define_loss : wing_loss, mse_loss ,adaptive_wing_loss') # 损失函数定义 parser.add_argument('--init_lr', type=float, default = 1e-3, help = 'init learning Rate') # 初始化学习率 parser.add_argument('--lr_decay', type=float, default = 0.1, -- GitLab