import torch
import torch.nn as nn
import torch.nn.functional as F

from pytorch_widedeep.wdtypes import *  # noqa: F403

use_cuda = torch.cuda.is_available()


class TweedieLoss(nn.Module):
    r"""
    Tweedie loss for extremely unbalanced zero-inflated data, i.e. the
    Tweedie deviance for a power parameter ``1 < p < 2``

    Defined as ``-y * mu**(1 - p) / (1 - p) + mu**(2 - p) / (2 - p)``,
    which is minimised when the predicted mean ``mu`` equals the target ``y``
    """

    def __init__(self):
        super().__init__()

    def forward(self, input: Tensor, target: Tensor, p: float = 1.5) -> Tensor:
        # the deviance is only defined for strictly positive predictions
        assert input.min() > 0, "All input values must be > 0"
        assert target.min() >= 0, "All target values must be >= 0"
        loss = -target * torch.pow(input, 1 - p) / (1 - p) + torch.pow(
            input, 2 - p
        ) / (2 - p)
        return torch.mean(loss)

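# NOTE: illustrative usage sketch, not part of the original file. With the
# default power p=1.5 the per-sample Tweedie loss above simplifies to
# 2 * y / sqrt(mu) + 2 * sqrt(mu), which is minimised at mu == y, e.g.
#
#     loss_fn = TweedieLoss()
#     mu = torch.tensor([2.0, 4.0])  # predicted (strictly positive) means
#     y = torch.tensor([2.0, 4.0])   # observed targets
#     loss_fn(mu, y)                 # tensor(6.8284)
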

class QuantileLoss(nn.Module):
    r"""
    Quantile loss, i.e. a quantile of ``q=0.5`` will give half of the mean
    absolute error as it is calculated as
    ``max(q * (y-y_pred), (1-q) * (y_pred-y))``

    All credits go to `pytorch-forecasting
    <https://pytorch-forecasting.readthedocs.io/en/latest/_modules/pytorch_forecasting/metrics.html#QuantileLoss>`_
    """

    def __init__(
        self,
        quantiles: List[float] = [0.02, 0.1, 0.25, 0.5, 0.75, 0.9, 0.98],
    ):
        r"""
        Parameters
        ----------
        quantiles: List[float]
            quantiles for the metric
        """
        super().__init__()
        self.quantiles = quantiles

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        # calculate the quantile loss per quantile; the per-sample,
        # per-quantile losses are returned unreduced
        losses = []
        for i, q in enumerate(self.quantiles):
            errors = target - input[..., i]
            losses.append(torch.max((q - 1) * errors, q * errors).unsqueeze(-1))
        losses = torch.cat(losses, dim=2)

        return losses

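# NOTE: illustrative usage sketch, not part of the original file. As in
# pytorch-forecasting, the last dimension of ``input`` is expected to carry
# one prediction per quantile (hence ``torch.cat(..., dim=2)`` above), and
# the losses come back unreduced:
#
#     loss_fn = QuantileLoss()       # 7 default quantiles
#     input = torch.rand(8, 5, 7)    # (batch, horizon, n_quantiles)
#     target = torch.rand(8, 5)      # (batch, horizon)
#     loss_fn(input, target).shape   # torch.Size([8, 5, 7])
#     loss_fn(input, target).mean()  # reduce to a scalar for backprop
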

class ZILNLoss(nn.Module):
    r"""Adjusted implementation of the `Zero Inflated LogNormal loss
    <https://arxiv.org/pdf/1912.07753.pdf>` and its `code
    <https://github.com/google/lifetime_value/blob/master/lifetime_value/zero_inflated_lognormal.py>`
    """

    def __init__(self):
        super().__init__()

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        r"""
        Parameters
        ----------
        input: Tensor
            input tensor with predictions (not probabilities)
        target: Tensor
            target tensor with the actual target values

        Examples
        --------
        >>> import torch
        >>>
        >>> from pytorch_widedeep.losses import ZILNLoss
        >>>
        >>> # REGRESSION
        >>> target = torch.tensor([[0., 1.5]]).view(-1, 1)
        >>> input = torch.tensor([[.1, .2, .3], [.4, .5, .6]])
        >>> ZILNLoss()(input, target)
        tensor(1.3114)
        """
        positive = (target > 0).float()

        assert input.shape == torch.Size([target.shape[0], 3]), "Wrong shape of input."
        # the three input columns are: the classification logit for
        # p(target > 0), and the loc and (pre-softplus) scale of the
        # LogNormal distribution used for the positive targets
        positive_input = input[..., :1]

        classification_loss = F.binary_cross_entropy_with_logits(
            positive_input, positive, reduction="none"
        ).flatten()

        loc = input[..., 1:2]
        # clamp the scale from below to avoid zero / degenerate scales
        scale = torch.maximum(
            F.softplus(input[..., 2:]),
            torch.sqrt(torch.Tensor([torch.finfo(torch.float32).eps])),
        )
        # replace the zero targets with ones so that log_prob is well
        # defined; their contribution is masked out by ``positive``
        safe_labels = positive * target + (1 - positive) * torch.ones_like(target)

        regression_loss = -torch.mean(
            positive
            * torch.distributions.log_normal.LogNormal(loc=loc, scale=scale).log_prob(
                safe_labels
            ),
            dim=-1,
        )

        return torch.mean(classification_loss + regression_loss)

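# NOTE: illustrative usage sketch, not part of the original file. A model
# feeding ZILNLoss needs a final layer with three outputs per sample
# (classification logit, loc, raw scale):
#
#     input = torch.randn(16, 3)
#     target = torch.cat([torch.zeros(8, 1), torch.rand(8, 1) * 10])
#     ZILNLoss()(input, target)  # scalar loss
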

class FocalLoss(nn.Module):
    r"""Implementation of the `focal loss
    <https://arxiv.org/pdf/1708.02002.pdf>`_ for both binary and
    multiclass classification

    :math:`FL(p_t) = - \alpha (1 - p_t)^{\gamma} log(p_t)`

    where, for a case of a binary classification problem

    :math:`\begin{equation} p_t= \begin{cases}p, & \text{if $y=1$}.\\1-p, & \text{otherwise}. \end{cases} \end{equation}`

    Parameters
    ----------
    alpha: float
        Focal Loss ``alpha`` parameter
    gamma: float
        Focal Loss ``gamma`` parameter
    """

    def __init__(self, alpha: float = 0.25, gamma: float = 1.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def _get_weight(self, p: Tensor, t: Tensor) -> Tensor:
        # focal weight w * (1 - p_t)^gamma, detached so it enters the
        # BCE below as a fixed per-element weight
        pt = p * t + (1 - p) * (1 - t)  # type: ignore
        w = self.alpha * t + (1 - self.alpha) * (1 - t)  # type: ignore
        return (w * (1 - pt).pow(self.gamma)).detach()  # type: ignore

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        r"""
        Parameters
        ----------
        input: Tensor
            input tensor with predictions (not probabilities)
        target: Tensor
            target tensor with the actual classes

        Examples
        --------
        >>> import torch
        >>>
        >>> from pytorch_widedeep.losses import FocalLoss
        >>>
        >>> # BINARY
        >>> target = torch.tensor([0, 1, 0, 1]).view(-1, 1)
        >>> input = torch.tensor([[0.6, 0.7, 0.3, 0.8]]).t()
        >>> FocalLoss()(input, target)
        tensor(0.1762)
        >>>
        >>> # MULTICLASS
        >>> target = torch.tensor([1, 0, 2]).view(-1, 1)
        >>> input = torch.tensor([[0.2, 0.5, 0.3], [0.8, 0.1, 0.1], [0.7, 0.2, 0.1]])
        >>> FocalLoss()(input, target)
        tensor(0.2573)
        """
        input_prob = torch.sigmoid(input)
        if input.size(1) == 1:
            input_prob = torch.cat([1 - input_prob, input_prob], dim=1)  # type: ignore
            num_class = 2
        else:
            num_class = input_prob.size(1)
        binary_target = torch.eye(num_class)[target.squeeze().long()]
        if use_cuda:
            binary_target = binary_target.cuda()
        binary_target = binary_target.contiguous()
        weight = self._get_weight(input_prob, binary_target)
        return F.binary_cross_entropy(
            input_prob, binary_target, weight, reduction="mean"
        )

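# NOTE: illustrative sanity check, not part of the original file. With
# gamma=0 the modulating factor (1 - p_t)**gamma disappears, so
# FocalLoss(alpha=0.5, gamma=0.) is exactly half of the plain binary
# cross-entropy (the 0.5 being the constant alpha weight):
#
#     input = torch.tensor([[0.6, 0.7, 0.3, 0.8]]).t()
#     target = torch.tensor([0, 1, 0, 1]).view(-1, 1)
#     focal = FocalLoss(alpha=0.5, gamma=0.)(input, target)
#     bce = F.binary_cross_entropy(torch.sigmoid(input), target.float())
#     torch.isclose(focal, 0.5 * bce)  # tensor(True)
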

class MSLELoss(nn.Module):
    r"""mean squared log error"""

    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        r"""
        Parameters
        ----------
        input: Tensor
            input tensor with predictions
        target: Tensor
            target tensor with the actual values

        Examples
        --------
        >>> import torch
        >>> from pytorch_widedeep.losses import MSLELoss
        >>>
        >>> target = torch.tensor([1, 1.2, 0, 2]).view(-1, 1)
        >>> input = torch.tensor([0.6, 0.7, 0.3, 0.8]).view(-1, 1)
        >>> MSLELoss()(input, target)
        tensor(0.1115)
        """
        return self.mse(torch.log(input + 1), torch.log(target + 1))


class RMSELoss(nn.Module):
    r"""root mean squared error"""

    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        r"""
        Parameters
        ----------
        input: Tensor
            input tensor with predictions
        target: Tensor
            target tensor with the actual values

        Examples
        --------
        >>> import torch
        >>> from pytorch_widedeep.losses import RMSELoss
        >>>
        >>> target = torch.tensor([1, 1.2, 0, 2]).view(-1, 1)
        >>> input = torch.tensor([0.6, 0.7, 0.3, 0.8]).view(-1, 1)
        >>> RMSELoss()(input, target)
        tensor(0.6964)
        """
        return torch.sqrt(self.mse(input, target))


class RMSLELoss(nn.Module):
    r"""root mean squared log error"""

    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        r"""
        Parameters
        ----------
        input: Tensor
            input tensor with predictions
        target: Tensor
            target tensor with the actual values

        Examples
        --------
        >>> import torch
        >>> from pytorch_widedeep.losses import RMSLELoss
        >>>
        >>> target = torch.tensor([1, 1.2, 0, 2]).view(-1, 1)
        >>> input = torch.tensor([0.6, 0.7, 0.3, 0.8]).view(-1, 1)
        >>> RMSLELoss()(input, target)
        tensor(0.3339)
        """
        return torch.sqrt(self.mse(torch.log(input + 1), torch.log(target + 1)))
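
# NOTE: cross-check, not part of the original file. RMSLELoss is, by
# construction, the square root of MSLELoss, which the doctest values above
# confirm: sqrt(0.1115) ~= 0.3339. All three log-based losses require
# input and target values > -1 so that torch.log(x + 1) is finite.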