# loss.py (1021 bytes)
# Last commit: "update" by Eric.Lee
#-*-coding:utf-8-*-
# date:2019-05-20
# function: wing loss
import torch
import torch.nn as nn
import torch.optim as optim
import os
import math

def wing_loss(landmarks, labels, w=0.06, epsilon=0.01):
    """Compute the mean Wing loss between predicted and target landmarks.

    Wing loss (Feng et al., CVPR 2018) is logarithmic for small residuals
    and L1-like for large ones. For each residual x = landmarks - labels:

        loss(x) = w * ln(1 + |x| / epsilon)    if |x| < w
                  |x| - C                      otherwise
        C       = w * (1 - ln(1 + w / epsilon))

    C makes the two pieces meet continuously at |x| == w.

    Arguments:
        landmarks, labels: float tensors of the same (or broadcastable)
            shape, e.g. [batch_size, num_coords] where each row is a
            flattened landmark vector (x1, x2, ..., y1, y2, ...). A 1-D
            tensor (a single sample) or a higher-rank tensor is also accepted.
        w: width of the non-linear (log) region; residuals with |x| < w
            use the log branch.
        epsilon: controls the curvature of the log branch (smaller values
            give a steeper curve). Must be non-zero.

    Returns:
        A 0-dim float tensor: the per-element Wing loss averaged over all
        elements.
    """
    x = landmarks - labels
    absolute_x = torch.abs(x)
    # Offset that joins the log branch and the linear branch at |x| == w.
    c = w * (1.0 - math.log(1.0 + w / epsilon))

    losses = torch.where(
        absolute_x < w,
        w * torch.log(1.0 + absolute_x / epsilon),
        absolute_x - c,
    )
    # Every row has the same number of elements, so "mean per sample, then
    # mean over the batch" is exactly the global mean. A single mean also
    # works for 1-D inputs, where the previous torch.mean(dim=1) raised.
    return losses.mean()

def got_total_wing_loss(output, crop_landmarks):
    """Return the Wing loss of ``output`` measured against ``crop_landmarks``.

    Thin wrapper around ``wing_loss`` that uses its default ``w`` and
    ``epsilon`` values.
    """
    return wing_loss(output, crop_landmarks)