Commit fc796f92 authored by Pavol Mulinka

minor fixes after rebase

Parent 1a81bdb9
@@ -30,6 +30,7 @@ class TweedieLoss(nn.Module):
)
if weight is not None:
loss *= weight.expand_as(loss)
return torch.mean(loss)
@@ -58,20 +59,20 @@ class QuantileLoss(nn.Module):
assert input.shape == torch.Size(
[target.shape[0], len(self.quantiles)]
), f"Wrong shape of input, pred_dim of the model that is using QuantileLoss must be equal to number of quantiles, i.e. {len(self.quantiles)}."
), f"""Wrong shape of input, pred_dim of the model that is using QuantileLoss must be
equal to number of quantiles, i.e. {len(self.quantiles)}."""
target = target.view(-1, 1).float()
losses = []
for i, q in enumerate(self.quantiles):
errors = target - input[..., i]
losses.append(torch.max((q - 1) * errors, q * errors).unsqueeze(-1))
loss = torch.cat(losses, dim=2)
return torch.mean(loss)
loss = torch.cat(losses, dim=2)
        if weight is not None:
            loss *= weight.expand_as(loss)
        return torch.mean(loss)
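A minimal usage sketch for the weighted quantile loss above (illustrative values; it assumes the constructor accepts the `quantiles` list that the assertion refers to):

>>> import torch
>>> from pytorch_widedeep.losses import QuantileLoss
>>>
>>> quantiles = [0.25, 0.5, 0.75]
>>> input = torch.rand(8, len(quantiles))  # pred_dim must equal len(quantiles)
>>> target = torch.rand(8, 1)
>>> loss = QuantileLoss(quantiles)(input, target)  # scalar pinball loss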
class ZILNLoss(nn.Module):
@@ -206,11 +207,74 @@ class FocalLoss(nn.Module):
binary_target = binary_target.cuda()
binary_target = binary_target.contiguous()
weight = self._get_weight(input_prob, binary_target)
return F.binary_cross_entropy(
input_prob, binary_target, weight, reduction="mean"
)
class L1Loss(nn.Module):
r"""Based on
`Yang, Y., Zha, K., Chen, Y. C., Wang, H., & Katabi, D. (2021).
Delving into Deep Imbalanced Regression. arXiv preprint arXiv:2102.09554.`
and their `implementation
<https://github.com/YyzHarry/imbalanced-regression>`_
"""
def __init__(self):
super().__init__()
def forward(self, input: Tensor, target: Tensor, weight: Union[None, Tensor]=None) -> Tensor:
r"""
Parameters
----------
        input: Tensor
            input tensor with predictions (not probabilities)
        target: Tensor
            target tensor with the actual values
        weight: Tensor, optional
            sample weights used to rescale the element-wise loss before averaging
Examples
--------
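        A minimal sketch with arbitrary values:

        >>> import torch
        >>> from pytorch_widedeep.losses import L1Loss
        >>>
        >>> target = torch.tensor([1, 1.2, 0, 2]).view(-1, 1)
        >>> input = torch.tensor([0.6, 0.7, 0.3, 0.8]).view(-1, 1)
        >>> L1Loss()(input, target)
        tensor(0.6000)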
"""
loss = F.l1_loss(input, target, reduction='none')
if weight is not None:
loss *= weight.expand_as(loss)
loss = torch.mean(loss)
return loss
class MSEloss(nn.Module):
r"""Based on
`Yang, Y., Zha, K., Chen, Y. C., Wang, H., & Katabi, D. (2021).
Delving into Deep Imbalanced Regression. arXiv preprint arXiv:2102.09554.`
and their `implementation
<https://github.com/YyzHarry/imbalanced-regression>`_
"""
def __init__(self):
super().__init__()
def forward(self, input: Tensor, target: Tensor, weight: Union[None, Tensor]=None) -> Tensor:
r"""
Parameters
----------
        input: Tensor
            input tensor with predictions (not probabilities)
        target: Tensor
            target tensor with the actual values
        weight: Tensor, optional
            sample weights used to rescale the element-wise loss before averaging
Examples
--------
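        A minimal sketch with arbitrary values:

        >>> import torch
        >>> from pytorch_widedeep.losses import MSEloss
        >>>
        >>> target = torch.tensor([1, 1.2, 0, 2]).view(-1, 1)
        >>> input = torch.tensor([0.6, 0.7, 0.3, 0.8]).view(-1, 1)
        >>> MSEloss()(input, target)
        tensor(0.4850)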
"""
loss = (input - target) ** 2
if weight is not None:
loss *= weight.expand_as(loss)
loss = torch.mean(loss)
return loss
class MSLELoss(nn.Module):
r"""mean squared log error"""
@@ -243,7 +307,13 @@ class MSLELoss(nn.Module):
values <0 try to enforce positive values by activation function
on last layer with `trainer.enforce_positive_output=True`"""
assert target.min() >= 0, "All target values must be >=0"
return self.mse(torch.log(input + 1), torch.log(target + 1))
loss = (torch.log(input + 1) - torch.log(target + 1)) ** 2
if weight is not None:
loss *= weight.expand_as(loss)
loss = torch.mean(loss)
return loss
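A minimal sketch with arbitrary values (note the assertion above: all targets must be >= 0):

>>> import torch
>>> from pytorch_widedeep.losses import MSLELoss
>>>
>>> target = torch.tensor([1, 1.2, 0, 2]).view(-1, 1)
>>> input = torch.tensor([0.6, 0.7, 0.3, 0.8]).view(-1, 1)
>>> MSLELoss()(input, target)
tensor(0.1115)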
class RMSELoss(nn.Module):
@@ -251,7 +321,6 @@ class RMSELoss(nn.Module):
def __init__(self):
super().__init__()
self.mse = nn.MSELoss()
def forward(self, input: Tensor, target: Tensor, weight: Union[None, Tensor]=None) -> Tensor:
r"""
@@ -272,7 +341,13 @@ class RMSELoss(nn.Module):
>>> RMSELoss()(input, target)
tensor(0.6964)
"""
return torch.sqrt(self.mse(input, target))
loss = (input - target) ** 2
if weight is not None:
loss *= weight.expand_as(loss)
loss = torch.mean(loss)
loss = torch.sqrt(loss)
return loss
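A minimal sketch of the new `weight` argument (arbitrary values; an all-ones weight reproduces the unweighted RMSE):

>>> import torch
>>> from pytorch_widedeep.losses import RMSELoss
>>>
>>> target = torch.tensor([1, 1.2, 0, 2]).view(-1, 1)
>>> input = torch.tensor([0.6, 0.7, 0.3, 0.8]).view(-1, 1)
>>> RMSELoss()(input, target, weight=torch.ones_like(target))
tensor(0.6964)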
class RMSLELoss(nn.Module):
@@ -308,41 +383,16 @@ class RMSLELoss(nn.Module):
on last layer with `trainer.enforce_positive_output=True`"""
assert target.min() >= 0, "All target values must be >=0"
return torch.sqrt(self.mse(torch.log(input + 1), torch.log(target + 1)))
class MSEloss(nn.Module):
r"""Based on
`Yang, Y., Zha, K., Chen, Y. C., Wang, H., & Katabi, D. (2021).
Delving into Deep Imbalanced Regression. arXiv preprint arXiv:2102.09554.`
and their `implementation
<https://github.com/YyzHarry/imbalanced-regression>`_
"""
def __init__(self):
super().__init__()
def forward(self, input: Tensor, target: Tensor, weight: Union[None, Tensor]=None) -> Tensor:
r"""
Parameters
----------
input: Tensor
input tensor with predictions (not probabilities)
target: Tensor
target tensor with the actual classes
Examples
--------
"""
loss = (input - target) ** 2
loss = (torch.log(input + 1) - torch.log(target + 1)) ** 2
if weight is not None:
loss *= weight.expand_as(loss)
loss = torch.mean(loss)
loss = torch.sqrt(loss)
return loss
class L1Loss(nn.Module):
class FocalL1Loss(nn.Module):
r"""Based on
`Yang, Y., Zha, K., Chen, Y. C., Wang, H., & Katabi, D. (2021).
Delving into Deep Imbalanced Regression. arXiv preprint arXiv:2102.09554.`
@@ -354,7 +404,7 @@ class L1Loss(nn.Module):
super().__init__()
self.mse = nn.MSELoss()
def forward(self, input: Tensor, target: Tensor, weight: Union[None, Tensor]=None) -> Tensor:
    def forward(self, input: Tensor, target: Tensor, weight: Union[None, Tensor] = None, activate: str = 'sigmoid', beta: float = 0.2, gamma: float = 1.0) -> Tensor:
r"""
Parameters
----------
@@ -367,6 +417,8 @@ class L1Loss(nn.Module):
--------
"""
loss = F.l1_loss(input, target, reduction='none')
loss *= (torch.tanh(beta * torch.abs(input - target))) ** gamma if activate == 'tanh' else \
(2 * torch.sigmoid(beta * torch.abs(input - target)) - 1) ** gamma
if weight is not None:
loss *= weight.expand_as(loss)
loss = torch.mean(loss)
@@ -405,7 +457,7 @@ class L1Loss(nn.Module):
return loss
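The focal factor above rescales each absolute error e by (2·sigmoid(beta·|e|) − 1)^gamma, or by tanh(beta·|e|)^gamma when `activate='tanh'`, so small residuals are down-weighted relative to large ones. A minimal sketch with arbitrary values:

>>> import torch
>>> from pytorch_widedeep.losses import FocalL1Loss
>>>
>>> target = torch.tensor([1, 1.2, 0, 2]).view(-1, 1)
>>> input = torch.tensor([0.6, 0.7, 0.3, 0.8]).view(-1, 1)
>>> FocalL1Loss()(input, target)  # defaults: activate='sigmoid', beta=0.2, gamma=1
tensor(0.0483)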
class FocalL1Loss(nn.Module):
class FocalRMSELoss(nn.Module):
r"""Based on
`Yang, Y., Zha, K., Chen, Y. C., Wang, H., & Katabi, D. (2021).
Delving into Deep Imbalanced Regression. arXiv preprint arXiv:2102.09554.`
@@ -429,12 +481,13 @@ class FocalL1Loss(nn.Module):
Examples
--------
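        A minimal sketch with arbitrary values, assuming the same defaults
        (activate='sigmoid', beta=0.2, gamma=1) as FocalL1Loss above:

        >>> import torch
        >>> from pytorch_widedeep.losses import FocalRMSELoss
        >>>
        >>> target = torch.tensor([1, 1.2, 0, 2]).view(-1, 1)
        >>> input = torch.tensor([0.6, 0.7, 0.3, 0.8]).view(-1, 1)
        >>> loss = FocalRMSELoss()(input, target)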
"""
loss = F.l1_loss(input, target, reduction='none')
loss = (input - target) ** 2
loss *= (torch.tanh(beta * torch.abs(input - target))) ** gamma if activate == 'tanh' else \
(2 * torch.sigmoid(beta * torch.abs(input - target)) - 1) ** gamma
if weight is not None:
loss *= weight.expand_as(loss)
loss = torch.mean(loss)
loss = torch.sqrt(loss)
return loss
@@ -33,6 +33,7 @@ class _LossAliases:
"root_mean_squared_log_error": ["root_mean_squared_log_error", "rmsle"],
"zero_inflated_lognormal": ["zero_inflated_lognormal", "ziln"],
"focalmse": ["focalmse"],
"focalrmse": ["focalrmse"],
"focall1": ["focall1"],
"huber": ["huber"],
"quantile": ["quantile"],
@@ -80,6 +81,7 @@ class _ObjectiveToMethod:
"tweedie": "regression",
"quantile": "qregression",
"focalmse": "regression",
"focalrmse": "regression",
"focall1": "regression",
"huber": "regression",
}
@@ -6,10 +6,10 @@ from torch.utils.data import TensorDataset
from sklearn.model_selection import train_test_split
from pytorch_widedeep.losses import (
L1Loss,
MSEloss,
MSLELoss,
RMSELoss,
ZILNLoss,
FocalLoss,
RMSLELoss,
TweedieLoss,
QuantileLoss,
@@ -17,6 +17,9 @@ from pytorch_widedeep.losses import (
L1Loss,
FocalMSELoss,
FocalL1Loss,
FocalMSELoss,
FocalRMSELoss,
FocalLoss,
HuberLoss,
MSEloss,
)
@@ -342,5 +345,7 @@ def alias_to_loss(loss_fn: str, **kwargs):  # noqa: C901
return FocalL1Loss()
if loss_fn in _LossAliases.get("focalmse"):
return FocalMSELoss()
if loss_fn in _LossAliases.get("focalrmse"):
return FocalRMSELoss()
if "focal_loss" in loss_fn:
return FocalLoss(**kwargs)
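The new "focalrmse" alias resolves to the class added above. A minimal sketch (the import path of `alias_to_loss` is not shown in this diff):

>>> loss_fn = alias_to_loss("focall1")    # FocalL1Loss()
>>> loss_fn = alias_to_loss("focalrmse")  # FocalRMSELoss()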