提交 da7ef089 编写于 作者: Eric.Lee2021's avatar Eric.Lee2021 🚴🏻

Update loss.py

上级 6aa7d7bb
......@@ -38,36 +38,3 @@ def wing_loss(landmarks, labels, w=10., epsilon=2.):
def got_total_wing_loss(output,crop_landmarks):
loss = wing_loss(output, crop_landmarks)
return loss
'''
AdaptiveWingLoss
'''
class AdaptiveWingLoss(nn.Module):
def __init__(self, omega=14, theta=0.5, epsilon=1, alpha=2.1):
super(AdaptiveWingLoss, self).__init__()
self.omega = omega
self.theta = theta
self.epsilon = epsilon
self.alpha = alpha
def forward(self, pred, target):
'''
:param pred: BxNxHxH
:param target: BxNxHxH
:return:
'''
y = target
y_hat = pred
delta_y = (y - y_hat).abs()
delta_y1 = delta_y[delta_y < self.theta]
delta_y2 = delta_y[delta_y >= self.theta]
y1 = y[delta_y < self.theta]
y2 = y[delta_y >= self.theta]
loss1 = self.omega * torch.log(1 + torch.pow(delta_y1 / self.omega, self.alpha - y1))
A = self.omega * (1 / (1 + torch.pow(self.theta / self.epsilon, self.alpha - y2))) * (self.alpha - y2) * (
torch.pow(self.theta / self.epsilon, self.alpha - y2 - 1)) * (1 / self.epsilon)
C = self.theta * A - self.omega * torch.log(1 + torch.pow(self.theta / self.epsilon, self.alpha - y2))
loss2 = A * delta_y2 - C
return (loss1.sum() + loss2.sum()) / (len(loss1) + len(loss2))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册