From da7ef08933ca4f239cb3d2b49bdeec5b1568ad4e Mon Sep 17 00:00:00 2001 From: "Eric.Lee2021" <305141918@qq.com> Date: Tue, 11 May 2021 11:38:51 +0800 Subject: [PATCH] Update loss.py --- loss/loss.py | 33 --------------------------------- 1 file changed, 33 deletions(-) diff --git a/loss/loss.py b/loss/loss.py index 620db1d..d191153 100644 --- a/loss/loss.py +++ b/loss/loss.py @@ -38,36 +38,3 @@ def wing_loss(landmarks, labels, w=10., epsilon=2.): def got_total_wing_loss(output,crop_landmarks): loss = wing_loss(output, crop_landmarks) return loss - -''' -AdaptiveWingLoss -''' - -class AdaptiveWingLoss(nn.Module): - def __init__(self, omega=14, theta=0.5, epsilon=1, alpha=2.1): - super(AdaptiveWingLoss, self).__init__() - self.omega = omega - self.theta = theta - self.epsilon = epsilon - self.alpha = alpha - - def forward(self, pred, target): - ''' - :param pred: BxNxHxH - :param target: BxNxHxH - :return: - ''' - - y = target - y_hat = pred - delta_y = (y - y_hat).abs() - delta_y1 = delta_y[delta_y < self.theta] - delta_y2 = delta_y[delta_y >= self.theta] - y1 = y[delta_y < self.theta] - y2 = y[delta_y >= self.theta] - loss1 = self.omega * torch.log(1 + torch.pow(delta_y1 / self.omega, self.alpha - y1)) - A = self.omega * (1 / (1 + torch.pow(self.theta / self.epsilon, self.alpha - y2))) * (self.alpha - y2) * ( - torch.pow(self.theta / self.epsilon, self.alpha - y2 - 1)) * (1 / self.epsilon) - C = self.theta * A - self.omega * torch.log(1 + torch.pow(self.theta / self.epsilon, self.alpha - y2)) - loss2 = A * delta_y2 - C - return (loss1.sum() + loss2.sum()) / (len(loss1) + len(loss2)) -- GitLab