提交 ea8f5763 编写于 作者: Eric.Lee2021's avatar Eric.Lee2021 🚴🏻

add adaptive_wing_loss

上级 c5c3dbd3
#-*-coding:utf-8-*- #-*-coding:utf-8-*-
# date:2019-05-20 # date:2019-05-20
# function: wing loss # function: wing loss & adaptive wing loss
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.optim as optim import torch.optim as optim
import os import os
import math import math
def wing_loss(landmarks, labels, w=0.06, epsilon=0.01): '''
wing_loss
'''
# def wing_loss(landmarks, labels, w=0.06, epsilon=0.01):
def wing_loss(landmarks, labels, w=10., epsilon=2.):
""" """
Arguments: Arguments:
landmarks, labels: float tensors with shape [batch_size, landmarks]. landmarks means x1,x2,x3,x4...y1,y2,y3,y4 1-D landmarks, labels: float tensors with shape [batch_size, landmarks]. landmarks means x1,x2,x3,x4...y1,y2,y3,y4 1-D
...@@ -34,3 +38,36 @@ def wing_loss(landmarks, labels, w=0.06, epsilon=0.01): ...@@ -34,3 +38,36 @@ def wing_loss(landmarks, labels, w=0.06, epsilon=0.01):
def got_total_wing_loss(output,crop_landmarks): def got_total_wing_loss(output,crop_landmarks):
loss = wing_loss(output, crop_landmarks) loss = wing_loss(output, crop_landmarks)
return loss return loss
'''
AdaptiveWingLoss
'''
class AdaptiveWingLoss(nn.Module):
def __init__(self, omega=14, theta=0.5, epsilon=1, alpha=2.1):
super(AdaptiveWingLoss, self).__init__()
self.omega = omega
self.theta = theta
self.epsilon = epsilon
self.alpha = alpha
def forward(self, pred, target):
'''
:param pred: BxNxHxH
:param target: BxNxHxH
:return:
'''
y = target
y_hat = pred
delta_y = (y - y_hat).abs()
delta_y1 = delta_y[delta_y < self.theta]
delta_y2 = delta_y[delta_y >= self.theta]
y1 = y[delta_y < self.theta]
y2 = y[delta_y >= self.theta]
loss1 = self.omega * torch.log(1 + torch.pow(delta_y1 / self.omega, self.alpha - y1))
A = self.omega * (1 / (1 + torch.pow(self.theta / self.epsilon, self.alpha - y2))) * (self.alpha - y2) * (
torch.pow(self.theta / self.epsilon, self.alpha - y2 - 1)) * (1 / self.epsilon)
C = self.theta * A - self.omega * torch.log(1 + torch.pow(self.theta / self.epsilon, self.alpha - y2))
loss2 = A * delta_y2 - C
return (loss1.sum() + loss2.sum()) / (len(loss1) + len(loss2))
...@@ -99,8 +99,10 @@ def trainer(ops,f_log): ...@@ -99,8 +99,10 @@ def trainer(ops,f_log):
print('/**********************************************/') print('/**********************************************/')
# 损失函数 # 损失函数
if ops.loss_define != 'wing_loss': if ops.loss_define == 'mse_loss':
criterion = nn.MSELoss(reduce=True, reduction='mean') criterion = nn.MSELoss(reduce=True, reduction='mean')
elif ops.loss_define == 'adaptive_wing_loss':
criterion = AdaptiveWingLoss()
step = 0 step = 0
idx = 0 idx = 0
...@@ -153,7 +155,7 @@ def trainer(ops,f_log): ...@@ -153,7 +155,7 @@ def trainer(ops,f_log):
print(' %s - %s - epoch [%s/%s] (%s/%s):'%(loc_time,ops.model,epoch,ops.epochs,i,int(dataset.__len__()/ops.batch_size)),\ print(' %s - %s - epoch [%s/%s] (%s/%s):'%(loc_time,ops.model,epoch,ops.epochs,i,int(dataset.__len__()/ops.batch_size)),\
'Mean Loss : %.6f - Loss: %.6f'%(loss_mean/loss_idx,loss.item()),\ 'Mean Loss : %.6f - Loss: %.6f'%(loss_mean/loss_idx,loss.item()),\
' lr : %.8f'%init_lr,' bs :',ops.batch_size,\ ' lr : %.8f'%init_lr,' bs :',ops.batch_size,\
' img_size: %s x %s'%(ops.img_size[0],ops.img_size[1]),' best_loss: %.6f'%best_loss) ' img_size: %s x %s'%(ops.img_size[0],ops.img_size[1]),' best_loss: %.6f'%best_loss, " {}".format(ops.loss_define))
# 计算梯度 # 计算梯度
loss.backward() loss.backward()
# 优化器对模型参数更新 # 优化器对模型参数更新
...@@ -162,7 +164,7 @@ def trainer(ops,f_log): ...@@ -162,7 +164,7 @@ def trainer(ops,f_log):
optimizer.zero_grad() optimizer.zero_grad()
step += 1 step += 1
torch.save(model_.state_dict(), ops.model_exp + '{}-size-{}-model_epoch-{}.pth'.format(ops.model,ops.img_size[0],epoch)) torch.save(model_.state_dict(), ops.model_exp + '{}-size-{}-loss-{}-model_epoch-{}.pth'.format(ops.model,ops.img_size[0],ops.loss_define,epoch))
except Exception as e: except Exception as e:
print('Exception : ',e) # 打印异常 print('Exception : ',e) # 打印异常
...@@ -192,8 +194,8 @@ if __name__ == "__main__": ...@@ -192,8 +194,8 @@ if __name__ == "__main__":
help = 'imageNet_Pretrain') # 初始化学习率 help = 'imageNet_Pretrain') # 初始化学习率
parser.add_argument('--fintune_model', type=str, default = 'None', parser.add_argument('--fintune_model', type=str, default = 'None',
help = 'fintune_model') # fintune model help = 'fintune_model') # fintune model
parser.add_argument('--loss_define', type=str, default = 'wing_loss', parser.add_argument('--loss_define', type=str, default = 'mse_loss',
help = 'define_loss') # 损失函数定义 help = 'define_loss : wing_loss, mse_loss ,adaptive_wing_loss') # 损失函数定义
parser.add_argument('--init_lr', type=float, default = 1e-3, parser.add_argument('--init_lr', type=float, default = 1e-3,
help = 'init learning Rate') # 初始化学习率 help = 'init learning Rate') # 初始化学习率
parser.add_argument('--lr_decay', type=float, default = 0.1, parser.add_argument('--lr_decay', type=float, default = 0.1,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册