loss.py 1.1 KB
Newer Older
E
update  
Eric.Lee 已提交
1 2
#-*-coding:utf-8-*-
# date:2019-05-20
Eric.Lee2021's avatar
Eric.Lee2021 已提交
3
# function: wing loss  & adaptive wing loss
E
update  
Eric.Lee 已提交
4 5 6 7 8 9
import torch
import torch.nn as nn
import torch.optim as optim
import os
import math

Eric.Lee2021's avatar
Eric.Lee2021 已提交
10 11 12 13 14
'''
wing_loss
'''
# def wing_loss(landmarks, labels, w=0.06, epsilon=0.01):
def wing_loss(landmarks, labels, w=10., epsilon=2.):
    """Compute the wing loss between predicted and target landmarks.

    With ``|x| = |landmarks - labels|`` taken element-wise, each element is
    penalised by ``w * log(1 + |x| / epsilon)`` when ``|x| < w`` and by
    ``|x| - c`` otherwise, where ``c = w * (1 - log(1 + w / epsilon))`` is
    the offset that makes the two pieces agree at ``|x| == w``.

    Arguments:
        landmarks, labels: float tensors with shape [batch_size, landmarks].
            Each row is a flattened 1-D layout x1,x2,x3,x4...y1,y2,y3,y4.
        w: a float, the error magnitude below which the logarithmic piece
            is used.
        epsilon: a float, the scale dividing ``|x|`` inside the logarithm.
    Returns:
        a float tensor with shape []: the per-element losses averaged over
        dim 1 and then over the remaining elements.
    """
    abs_error = torch.abs(landmarks - labels)

    # Offset joining the linear piece to the logarithmic piece at |x| == w.
    c = w * (1.0 - math.log(1.0 + w / epsilon))

    losses = torch.where(
        abs_error < w,
        w * torch.log(1.0 + abs_error / epsilon),
        abs_error - c,
    )

    # Average over the landmark axis first, then over the batch.
    per_sample = torch.mean(losses, dim=1, keepdim=True)
    return torch.mean(per_sample)

def got_total_wing_loss(output,crop_landmarks):
    """Return the wing loss of `output` against `crop_landmarks`.

    Thin wrapper that forwards both arguments to `wing_loss`, using that
    function's default `w` and `epsilon`.
    """
    return wing_loss(output, crop_landmarks)