import torch import torch.nn as nn import torch.nn.functional as F from .wdtypes import * class FocalLoss(nn.Module): def __init__(self, alpha:float=0.25, gamma:float=1.): super().__init__() self.alpha = alpha self.gamma = gamma def get_weight(self, x:Tensor, t:Tensor) -> Tensor: self.alpha,self.gamma = 0.25,1 p = x.sigmoid() pt = p*t + (1-p)*(1-t) w = self.alpha*t + (1-self.alpha)*(1-t) return (w * (1-pt).pow(self.gamma)).detach() def forward(self, input:Tensor, target:Tensor) -> Tensor: if input.size(1) == 1: input = torch.cat([1-input, input], axis=1) num_class = 2 else: num_class = input.size(1) binary_target = torch.eye(num_class)[target.data.cpu().long()] binary_target = binary_target.contiguous() weight = self.get_weight(input, binary_target) return F.binary_cross_entropy_with_logits(input, binary_target, weight, reduction='sum')/num_class