losses.py
import torch
import torch.nn as nn
import torch.nn.functional as F

from .wdtypes import *

use_cuda = torch.cuda.is_available()


class FocalLoss(nn.Module):
    r"""Focal Loss (Lin et al., 2017) applied to raw logits, for binary and
    multiclass classification.

    `alpha` balances positive and negative classes; `gamma` down-weights
    easy, well-classified examples so training focuses on the hard ones.
    """
    def __init__(self, alpha: float = 0.25, gamma: float = 1.):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def get_weight(self, x: Tensor, t: Tensor) -> Tensor:
        # p: predicted probabilities; pt: probability assigned to the true class
        p = x.sigmoid()
        pt = p * t + (1 - p) * (1 - t)
        # alpha class-balancing term combined with the (1 - pt)^gamma focusing term
        w = self.alpha * t + (1 - self.alpha) * (1 - t)
        return (w * (1 - pt).pow(self.gamma)).detach()

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        # If there is a single output column, expand it to two columns
        # (complement and original) so both classes are represented
        if input.size(1) == 1:
            input = torch.cat([1 - input, input], dim=1)
            num_class = 2
        else:
            num_class = input.size(1)
        # One-hot encode the integer targets
        binary_target = torch.eye(num_class)[target.long()]
        if use_cuda:
            binary_target = binary_target.cuda()
        binary_target = binary_target.contiguous()
        weight = self.get_weight(input, binary_target)
        return F.binary_cross_entropy_with_logits(
            input, binary_target, weight, reduction='mean'
        )
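

# ---------------------------------------------------------------------------
# Usage sketch (illustrative addition, not part of the original module):
# a minimal example of applying FocalLoss to raw logits for a 3-class
# problem. The tensor names, shapes and gamma value below are assumptions.
#
#     criterion = FocalLoss(alpha=0.25, gamma=2.)
#     logits = torch.randn(8, 3, requires_grad=True)  # raw model outputs
#     labels = torch.randint(0, 3, (8,))              # integer class ids
#     loss = criterion(logits, labels)
#     loss.backward()
# ---------------------------------------------------------------------------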