# -*- coding: utf-8 -*-
# date: 2019-05-20
# function: wing loss & adaptive wing loss
import torch
import torch.nn as nn
import torch.optim as optim
import os
import math


# ---------------------------------------------------------------------------
# Wing loss
# ---------------------------------------------------------------------------
def wing_loss(landmarks, labels, w=10., epsilon=2.):
    """Compute the mean Wing loss between predicted and target landmarks.

    Per element, with ``x = landmarks - labels``::

        w * ln(1 + |x| / epsilon)   if |x| < w
        |x| - c                     otherwise, where c = w - w * ln(1 + w / epsilon)

    The constant ``c`` makes the two pieces meet at ``|x| == w``.

    Arguments:
        landmarks, labels: float tensors of the same shape, typically
            [batch_size, num_coords], with each row laid out as
            x1, x2, x3, x4, ..., y1, y2, y3, y4 (the layout is not checked
            here). A 1-D tensor is treated as a single sample.
        w: width of the non-linear (logarithmic) region. A previous variant
            of this function used w=0.06, epsilon=0.01, presumably for
            normalized coordinates.
        epsilon: curvature of the logarithmic region.

    Returns:
        A 0-dim float tensor: the mean loss over all elements.
    """
    x = landmarks - labels
    c = w * (1.0 - math.log(1.0 + w / epsilon))
    absolute_x = torch.abs(x)
    losses = torch.where(
        absolute_x < w,
        w * torch.log(1.0 + absolute_x / epsilon),
        absolute_x - c,
    )
    # Mean over every element. For a [batch, coords] tensor this equals the
    # per-sample mean averaged over the batch; unlike mean(dim=1) it also
    # accepts 1-D input.
    return losses.mean()


def got_total_wing_loss(output, crop_landmarks):
    """Return the Wing loss of ``output`` against ``crop_landmarks``.

    Uses the default ``w`` / ``epsilon`` of :func:`wing_loss`.
    """
    return wing_loss(output, crop_landmarks)


# ---------------------------------------------------------------------------
# Adaptive Wing loss
# ---------------------------------------------------------------------------
class AdaptiveWingLoss(nn.Module):
    """Adaptive Wing loss for heatmap regression.

    Per element, with ``d = |y - y_hat|`` and exponent ``p = alpha - y``::

        omega * ln(1 + (d / epsilon) ** p)   if d < theta
        A * d - C                            if d >= theta

    ``A`` is the slope of the first piece at ``d == theta``, and ``C`` shifts
    the linear piece so that both pieces agree in value there.
    """

    def __init__(self, omega=14, theta=0.5, epsilon=1, alpha=2.1):
        super(AdaptiveWingLoss, self).__init__()
        self.omega = omega
        self.theta = theta
        self.epsilon = epsilon
        self.alpha = alpha

    def forward(self, pred, target):
        """Return the mean Adaptive Wing loss over all elements.

        :param pred: predicted tensor (documented as B x N x H x H heatmaps);
            must have the same shape as ``target``.
        :param target: ground-truth tensor of the same shape. Its values enter
            the exponent ``alpha - y``.
        :return: 0-dim tensor holding the mean loss. Returns 0 for empty input
            instead of dividing by zero.
        """
        y = target
        y_hat = pred
        delta_y = (y - y_hat).abs()

        # Boolean masks evaluate each piece only on its own elements.
        small = delta_y < self.theta
        large = delta_y >= self.theta
        delta_y1 = delta_y[small]
        delta_y2 = delta_y[large]
        y1 = y[small]
        y2 = y[large]

        # Non-linear piece. The scale is epsilon (not omega), so that it
        # joins the linear piece continuously at theta.
        loss1 = self.omega * torch.log(
            1 + torch.pow(delta_y1 / self.epsilon, self.alpha - y1))

        ratio = self.theta / self.epsilon
        exponent = self.alpha - y2
        # A: derivative of the non-linear piece at d == theta.
        A = (self.omega
             * (1 / (1 + torch.pow(ratio, exponent)))
             * exponent
             * torch.pow(ratio, exponent - 1)
             * (1 / self.epsilon))
        # C: makes A * theta - C equal the non-linear piece's value at theta.
        C = self.theta * A - self.omega * torch.log(1 + torch.pow(ratio, exponent))
        loss2 = A * delta_y2 - C

        count = loss1.numel() + loss2.numel()
        return (loss1.sum() + loss2.sum()) / max(count, 1)