# pytorch_widedeep/losses.py
import torch
import torch.nn as nn
import torch.nn.functional as F

from pytorch_widedeep.wdtypes import *  # noqa: F403

# Whether a CUDA device is available, evaluated once at import time.
use_cuda = torch.cuda.is_available()

class TweedieLoss(nn.Module):
    r"""Tweedie loss for extremely unbalanced, zero-inflated data.

    For strictly positive predictions :math:`\hat{y}`, non-negative targets
    :math:`y` and a power parameter :math:`p`, the per-element loss is

    :math:`-y \frac{\hat{y}^{1-p}}{1-p} + \frac{\hat{y}^{2-p}}{2-p}`

    and the returned value is its mean over all elements.

    All credits go to `Wenbo Shi
    <https://towardsdatascience.com/tweedie-loss-function-for-right-skewed-data-2c5ca470678f>`_
    and `this paper <https://arxiv.org/abs/1811.10192>`_.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input: Tensor, target: Tensor, p: float = 1.5) -> Tensor:
        r"""
        Parameters
        ----------
        input: Tensor
            input tensor with strictly positive predictions
        target: Tensor
            target tensor with the actual (non-negative) values
        p: float, default = 1.5
            Tweedie power parameter. ``p=1`` and ``p=2`` are not supported
            because the formula divides by ``1 - p`` and ``2 - p``

        Raises
        ------
        ValueError
            if ``p`` is 1 or 2
        AssertionError
            if any prediction is not strictly positive or any target is
            negative
        """
        if p in (1, 2):
            raise ValueError(
                f"p must be different from 1 and 2 (got p={p}), "
                "otherwise the loss divides by zero"
            )
        # input ** (1 - p) is infinite at 0 for p > 1, hence the strict check
        assert (
            input.min() > 0
        ), """All input values must be >0, if your model is predicting
            values <=0 try to enforce positive values by activation function
            on last layer with `trainer.enforce_positive_output=True`"""
        assert target.min() >= 0, "All target values must be >=0"
        loss = -target * torch.pow(input, 1 - p) / (1 - p) + torch.pow(
            input, 2 - p
        ) / (2 - p)
        return torch.mean(loss)


class QuantileLoss(nn.Module):
    """Quantile (pinball) loss averaged over several quantiles.

    For a quantile ``q`` the loss is defined as
    ``max(q * (y - y_pred), (1 - q) * (y_pred - y))``; e.g. ``q=0.5`` gives
    half of the absolute error. The returned value is the mean of this loss
    over all samples and all quantiles.

    All credits go to `pytorch-forecasting
    <https://pytorch-forecasting.readthedocs.io/en/latest/_modules/pytorch_forecasting/metrics.html#QuantileLoss>`_

    Parameters
    ----------
    quantiles: List[float]
        quantiles to compute the loss for. The model using this loss must
        output one column per quantile, in the same order
    """

    def __init__(
        self,
        quantiles: List[float] = [0.02, 0.1, 0.25, 0.5, 0.75, 0.9, 0.98],
    ):
        super().__init__()
        # copy, so that instances never share (and mutate) the default list
        self.quantiles = list(quantiles)

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        r"""
        Parameters
        ----------
        input: Tensor
            predictions with shape (N, Q), where N is the batch size and Q
            the number of quantiles
        target: Tensor
            target tensor with the actual values, with shape (N, 1) or (N,)

        Examples
        --------
        >>> import torch
        >>> from pytorch_widedeep.losses import QuantileLoss
        >>>
        >>> target = torch.tensor([2.0, 3.0]).view(-1, 1)
        >>> input = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
        >>> QuantileLoss(quantiles=[0.5, 0.9])(input, target)
        tensor(0.1500)
        """
        assert input.shape == torch.Size(
            [target.shape[0], len(self.quantiles)]
        ), f"Wrong shape of input, pred_dim of the model that is using QuantileLoss must be equal to number of quantiles, i.e. {len(self.quantiles)}."
        # Keep the target as an (N, 1) column: it broadcasts against the
        # (N, Q) predictions so that sample i is only compared with row i.
        target = target.reshape(-1, 1).float()
        q = torch.tensor(self.quantiles, dtype=input.dtype, device=input.device)
        errors = target - input  # (N, Q)
        losses = torch.max((q - 1) * errors, q * errors)
        return torch.mean(losses)


class ZILNLoss(nn.Module):
    r"""Adjusted implementation of the `Zero Inflated LogNormal loss
    <https://arxiv.org/pdf/1912.07753.pdf>`_ and its `code
    <https://github.com/google/lifetime_value/blob/master/lifetime_value/zero_inflated_lognormal.py>`_.

    The model is expected to output three values per sample:

    - a logit for "the target is positive";
    - the ``loc`` of a log-normal distribution fitted to the positive targets;
    - its ``scale``, before a ``softplus``.

    The loss combines two terms. The first is the binary cross-entropy of the
    logit. The second is the log-normal negative log-likelihood, computed only
    for the positive targets.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        r"""
        Parameters
        ----------
        input: Tensor
            input tensor with predictions (not probabilities) with shape
            (N,3), where N is the batch size
        target: Tensor
            target tensor with the actual (non-negative) values, with shape
            (N,1) or (N,)

        Examples
        --------
        >>> import torch
        >>>
        >>> from pytorch_widedeep.losses import ZILNLoss
        >>>
        >>> # REGRESSION
        >>> target = torch.tensor([[0., 1.5]]).view(-1, 1)
        >>> input = torch.tensor([[.1, .2, .3], [.4, .5, .6]])
        >>> ZILNLoss()(input, target)
        tensor(1.3114)
        """
        assert input.shape == torch.Size(
            [target.shape[0], 3]
        ), "Wrong shape of input, pred_dim of the model that is using ZILNLoss must be equal to 3."
        # (N, 1) column, matching the (N, 1) slices of ``input`` taken below
        target = target.reshape(-1, 1)
        positive = (target > 0).float()

        # classification term: is the target positive?
        positive_logits = input[..., :1]
        classification_loss = F.binary_cross_entropy_with_logits(
            positive_logits, positive, reduction="none"
        ).flatten()

        # regression term: log-normal likelihood, for positive targets only
        loc = input[..., 1:2]
        # Bound with a Python float rather than a CPU tensor, so this also
        # works when ``input`` lives on a GPU.
        scale = F.softplus(input[..., 2:]).clamp(
            min=torch.finfo(torch.float32).eps ** 0.5
        )
        # Non-positive targets are replaced by 1 so that log_prob stays finite.
        # Those entries are then zeroed out by the ``positive`` mask.
        safe_labels = positive * target + (1 - positive) * torch.ones_like(target)
        regression_loss = -torch.mean(
            positive
            * torch.distributions.log_normal.LogNormal(loc=loc, scale=scale).log_prob(
                safe_labels
            ),
            dim=-1,
        )

        return torch.mean(classification_loss + regression_loss)


class FocalLoss(nn.Module):
    r"""Implementation of the `focal loss
    <https://arxiv.org/pdf/1708.02002.pdf>`_ for both binary and
    multiclass classification

    :math:`FL(p_t) = -\alpha_t (1 - p_t)^{\gamma} log(p_t)`

    where, for a case of a binary classification problem

    :math:`\begin{equation} p_t= \begin{cases}p, & \text{if $y=1$}.\\1-p, & \text{otherwise}. \end{cases} \end{equation}`

    and :math:`\alpha_t = \alpha` if :math:`y=1` and :math:`1 - \alpha`
    otherwise. In the multiclass case each class is scored with its own
    sigmoid and treated as an independent binary problem.

    Parameters
    ----------
    alpha: float
        Focal Loss ``alpha`` parameter
    gamma: float
        Focal Loss ``gamma`` parameter
    """

    def __init__(self, alpha: float = 0.25, gamma: float = 1.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def _get_weight(self, p: Tensor, t: Tensor) -> Tensor:
        # Focal weight alpha_t * (1 - p_t)^gamma. It is detached, so it acts
        # as a constant per-element weight on the binary cross-entropy.
        pt = p * t + (1 - p) * (1 - t)  # type: ignore
        w = self.alpha * t + (1 - self.alpha) * (1 - t)  # type: ignore
        return (w * (1 - pt).pow(self.gamma)).detach()  # type: ignore

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        r"""
        Parameters
        ----------
        input: Tensor
            input tensor with predictions (not probabilities)
        target: Tensor
            target tensor with the actual classes

        Examples
        --------
        >>> import torch
        >>>
        >>> from pytorch_widedeep.losses import FocalLoss
        >>>
        >>> # BINARY
        >>> target = torch.tensor([0, 1, 0, 1]).view(-1, 1)
        >>> input = torch.tensor([[0.6, 0.7, 0.3, 0.8]]).t()
        >>> FocalLoss()(input, target)
        tensor(0.1762)
        >>>
        >>> # MULTICLASS
        >>> target = torch.tensor([1, 0, 2]).view(-1, 1)
        >>> input = torch.tensor([[0.2, 0.5, 0.3], [0.8, 0.1, 0.1], [0.7, 0.2, 0.1]])
        >>> FocalLoss()(input, target)
        tensor(0.2573)
        """
        input_prob = torch.sigmoid(input)
        if input.size(1) == 1:
            # binary: expand to two columns, [P(y=0), P(y=1)]
            input_prob = torch.cat([1 - input_prob, input_prob], dim=1)
            num_class = 2
        else:
            num_class = input_prob.size(1)
        # Build the one-hot targets on the predictions' device and dtype. This
        # avoids relying on an import-time CUDA flag. ``reshape(-1)`` is used
        # instead of ``squeeze`` so that a batch of one keeps shape
        # (1, num_class).
        class_idx = target.reshape(-1).long().to(input_prob.device)
        binary_target = torch.eye(
            num_class, dtype=input_prob.dtype, device=input_prob.device
        )[class_idx]
        weight = self._get_weight(input_prob, binary_target)
        return F.binary_cross_entropy(
            input_prob, binary_target, weight, reduction="mean"
        )


class MSLELoss(nn.Module):
    r"""mean squared log error, i.e. the mean of
    ``(log(1 + input) - log(1 + target)) ** 2``"""

    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        r"""
        Parameters
        ----------
        input: Tensor
            input tensor with (non-negative) predictions
        target: Tensor
            target tensor with the actual (non-negative) values

        Examples
        --------
        >>> import torch
        >>> from pytorch_widedeep.losses import MSLELoss
        >>>
        >>> target = torch.tensor([1, 1.2, 0, 2]).view(-1, 1)
        >>> input = torch.tensor([0.6, 0.7, 0.3, 0.8]).view(-1, 1)
        >>> MSLELoss()(input, target)
        tensor(0.1115)
        """
        assert (
            input.min() >= 0
        ), """All input values must be >=0, if your model is predicting
            values <0 try to enforce positive values by activation function
            on last layer with `trainer.enforce_positive_output=True`"""
        assert target.min() >= 0, "All target values must be >=0"
        # log1p(x) is more accurate than log(x + 1) for small x
        return self.mse(torch.log1p(input), torch.log1p(target))


class RMSELoss(nn.Module):
    r"""root mean squared error, i.e. ``sqrt(mean((input - target) ** 2))``"""

    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        r"""
        Parameters
        ----------
        input: Tensor
            input tensor with predictions (not probabilities)
        target: Tensor
            target tensor with the actual values

        Examples
        --------
        >>> import torch
        >>> from pytorch_widedeep.losses import RMSELoss
        >>>
        >>> target = torch.tensor([1, 1.2, 0, 2]).view(-1, 1)
        >>> input = torch.tensor([0.6, 0.7, 0.3, 0.8]).view(-1, 1)
        >>> RMSELoss()(input, target)
        tensor(0.6964)
        """
        return torch.sqrt(self.mse(input, target))


class RMSLELoss(nn.Module):
    r"""root mean squared log error, i.e. the square root of the mean of
    ``(log(1 + input) - log(1 + target)) ** 2``"""

    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        r"""
        Parameters
        ----------
        input: Tensor
            input tensor with (non-negative) predictions
        target: Tensor
            target tensor with the actual (non-negative) values

        Examples
        --------
        >>> import torch
        >>> from pytorch_widedeep.losses import RMSLELoss
        >>>
        >>> target = torch.tensor([1, 1.2, 0, 2]).view(-1, 1)
        >>> input = torch.tensor([0.6, 0.7, 0.3, 0.8]).view(-1, 1)
        >>> RMSLELoss()(input, target)
        tensor(0.3339)
        """
        assert (
            input.min() >= 0
        ), """All input values must be >=0, if your model is predicting
            values <0 try to enforce positive values by activation function
            on last layer with `trainer.enforce_positive_output=True`"""
        assert target.min() >= 0, "All target values must be >=0"
        # log1p(x) is more accurate than log(x + 1) for small x
        return torch.sqrt(self.mse(torch.log1p(input), torch.log1p(target)))