losses.py 1.2 KB
Newer Older
1 2 3 4 5 6
import torch
import torch.nn as nn
import torch.nn.functional as F

from .wdtypes import *

J
jrzaurin 已提交
7 8 9
# Whether PyTorch reports a usable CUDA device; evaluated once, at import time.
use_cuda = torch.cuda.is_available()


10
class FocalLoss(nn.Module):
    """Focal loss computed on raw logits.

    The integer class targets are one-hot encoded and every class column is
    scored with binary cross entropy (with logits). Each element is weighted
    by ``w * (1 - pt) ** gamma`` where ``pt`` is the sigmoid probability
    assigned to the correct binary outcome and ``w`` is ``alpha`` for
    positive entries and ``1 - alpha`` for negative entries.

    Parameters
    ----------
    alpha: float, default 0.25
        Weight given to positive (target == 1) entries of the one-hot
        target; negative entries receive ``1 - alpha``.
    gamma: float, default 1.0
        Focusing exponent applied to ``1 - pt``.
    """

    def __init__(self, alpha: float = 0.25, gamma: float = 1.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def get_weight(self, x: Tensor, t: Tensor) -> Tensor:
        """Return the per-element focal weights for logits ``x`` and targets ``t``.

        ``t`` is expected to hold 0/1 values with the same shape as ``x``.
        The result is detached, so the weights act as constants and no
        gradient flows through them.
        """
        p = x.sigmoid()
        # Probability assigned to the correct binary outcome of each entry.
        pt = p * t + (1 - p) * (1 - t)  # type: ignore
        # Class-balance term: alpha for positives, (1 - alpha) for negatives.
        w = self.alpha * t + (1 - self.alpha) * (1 - t)  # type: ignore
        return (w * (1 - pt).pow(self.gamma)).detach()  # type: ignore

    def forward(self, input: Tensor, target: Tensor) -> Tensor:  # type: ignore
        """Compute the mean focal loss.

        Parameters
        ----------
        input: Tensor
            Logits with the class dimension at index 1. A single column is
            treated as a binary problem and expanded to two columns.
        target: Tensor
            Integer class indices, one per row of ``input``. A trailing
            singleton dimension (column vector) is accepted.

        Returns
        -------
        Tensor
            Scalar tensor with the mean weighted binary cross entropy.
        """
        if input.size(1) == 1:
            # NOTE(review): the negative-class logit is built as ``1 - input``
            # (kept as in the original); ``-input`` would be the exact
            # sigmoid complement. Confirm the intent before changing.
            input = torch.cat([1 - input, input], dim=1)
            num_class = 2
        else:
            num_class = input.size(1)

        # Flatten so both (N,) and (N, 1) targets index one row per sample,
        # and index on the logits' device to avoid cross-device indexing.
        labels = target.reshape(-1).long().to(input.device)

        # Build the one-hot target directly with the logits' device and dtype
        # rather than relying on global CUDA availability, so CPU tensors keep
        # working on CUDA-capable hosts and non-float32 logits are supported.
        identity = torch.eye(num_class, dtype=input.dtype, device=input.device)
        binary_target = identity[labels].contiguous()

        weight = self.get_weight(input, binary_target)
        return F.binary_cross_entropy_with_logits(
            input, binary_target, weight, reduction="mean"
        )