loss.py 2.3 KB
Newer Older
E
update  
Eric.Lee 已提交
1 2
#-*-coding:utf-8-*-
# date:2019-05-20
Eric.Lee2021's avatar
Eric.Lee2021 已提交
3
# function: wing loss  & adaptive wing loss
E
update  
Eric.Lee 已提交
4 5 6 7 8 9
import torch
import torch.nn as nn
import torch.optim as optim
import os
import math

Eric.Lee2021's avatar
Eric.Lee2021 已提交
10 11 12 13 14
'''
wing_loss
'''
# def wing_loss(landmarks, labels, w=0.06, epsilon=0.01):
def wing_loss(landmarks, labels, w=10., epsilon=2.):
    """Compute the mean wing loss between predicted and target landmarks.

    With ``x = landmarks - labels`` the per-element loss is::

        w * ln(1 + |x| / epsilon)    if |x| < w
        |x| - c                      otherwise

    where ``c = w * (1 - ln(1 + w / epsilon))`` makes the two pieces meet
    at ``|x| == w``.

    Arguments:
        landmarks, labels: float tensors with broadcast-compatible shapes.
            The intended layout is [batch_size, landmarks], coordinates
            flattened as x1,x2,x3,x4...y1,y2,y3,y4, but any shape
            (including a single 1-D sample) is accepted.
        w: float, width of the logarithmic region around zero error.
        epsilon: float, curvature of the logarithmic region.
    Returns:
        a float tensor with shape [] holding the mean over all elements.
    """
    abs_error = torch.abs(landmarks - labels)
    # Offset that joins the linear piece to the logarithmic one at |x| == w.
    c = w * (1.0 - math.log(1.0 + w / epsilon))

    losses = torch.where(
        abs_error < w,
        w * torch.log(1.0 + abs_error / epsilon),
        abs_error - c,
    )

    # Every sample has the same number of elements, so the mean of the
    # per-sample means equals the global mean; reducing in one step also
    # works for inputs that have no dim=1.
    return torch.mean(losses)

def got_total_wing_loss(output, crop_landmarks):
    """Return the wing loss of ``output`` against ``crop_landmarks``.

    Thin convenience wrapper: forwards both tensors to ``wing_loss`` using
    that function's default ``w`` and ``epsilon``.
    """
    return wing_loss(output, crop_landmarks)
Eric.Lee2021's avatar
Eric.Lee2021 已提交
41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73

'''
AdaptiveWingLoss
'''

class AdaptiveWingLoss(nn.Module):
    """Adaptive wing loss for heatmap regression.

    For target value ``y`` and prediction ``y_hat``::

        omega * ln(1 + |(y - y_hat) / epsilon| ** (alpha - y))   if |y - y_hat| < theta
        A * |y - y_hat| - C                                      otherwise

    ``A`` and ``C`` are chosen so that the linear piece continues the
    logarithmic one smoothly at ``|y - y_hat| == theta``.
    """

    def __init__(self, omega=14, theta=0.5, epsilon=1, alpha=2.1):
        """
        :param omega: scale of the logarithmic region
        :param theta: error threshold between the logarithmic and linear pieces
        :param epsilon: normaliser applied to the error inside the power term
        :param alpha: base exponent; the effective exponent is ``alpha - y``
        """
        super(AdaptiveWingLoss, self).__init__()
        self.omega = omega
        self.theta = theta
        self.epsilon = epsilon
        self.alpha = alpha

    def forward(self, pred, target):
        '''
        :param pred: predicted heatmaps, BxNxHxH
        :param target: ground-truth heatmaps, same shape as pred
        :return: scalar tensor, loss averaged over all elements
        '''
        delta = (target - pred).abs()
        near = delta < self.theta
        far = delta >= self.theta
        delta_near, delta_far = delta[near], delta[far]
        y_near, y_far = target[near], target[far]

        # Logarithmic piece. The error is normalised by epsilon (not omega):
        # A and C below are derived under that assumption, and any other
        # divisor makes the loss jump at |y - y_hat| == theta.
        loss_near = self.omega * torch.log(
            1 + torch.pow(delta_near / self.epsilon, self.alpha - y_near))

        # Linear piece: A is the slope of the logarithmic piece at theta and
        # C the offset that makes both pieces agree there.
        ratio = self.theta / self.epsilon
        exponent = self.alpha - y_far
        ratio_pow = torch.pow(ratio, exponent)
        A = self.omega * (1 / (1 + ratio_pow)) * exponent * (
            torch.pow(ratio, exponent - 1)) * (1 / self.epsilon)
        C = self.theta * A - self.omega * torch.log(1 + ratio_pow)
        loss_far = A * delta_far - C

        return (loss_near.sum() + loss_far.sum()) / (len(loss_near) + len(loss_far))