From deafc3516566295d6f0ad1f0fdf7f81481dadcfe Mon Sep 17 00:00:00 2001 From: Pavol Mulinka Date: Tue, 14 Dec 2021 13:22:31 +0000 Subject: [PATCH] rebased to master --- pytorch_widedeep/losses.py | 194 ++++++++++++++++++++++- pytorch_widedeep/training/_wd_dataset.py | 67 +++++++- pytorch_widedeep/training/trainer.py | 12 +- 3 files changed, 261 insertions(+), 12 deletions(-) diff --git a/pytorch_widedeep/losses.py b/pytorch_widedeep/losses.py index cdc8ffd..ac08dbb 100644 --- a/pytorch_widedeep/losses.py +++ b/pytorch_widedeep/losses.py @@ -18,7 +18,7 @@ class TweedieLoss(nn.Module): def __init__(self): super().__init__() - def forward(self, input: Tensor, target: Tensor, p=1.5) -> Tensor: + def forward(self, input: Tensor, target: Tensor, ldsweight: Union[None, Tensor], p=1.5) -> Tensor: assert ( input.min() > 0 ), """All input values must be >=0, if your model is predicting @@ -28,6 +28,8 @@ class TweedieLoss(nn.Module): loss = -target * torch.pow(input, 1 - p) / (1 - p) + torch.pow(input, 2 - p) / ( 2 - p ) + if ldsweight is not None: + loss *= ldsweight.expand_as(loss) return torch.mean(loss) @@ -53,11 +55,18 @@ class QuantileLoss(nn.Module): super().__init__() self.quantiles = quantiles +<<<<<<< HEAD def forward(self, input: Tensor, target: Tensor) -> Tensor: assert input.shape == torch.Size([target.shape[0], len(self.quantiles)]), ( f"Wrong shape of input, pred_dim of the model that is using QuantileLoss must be equal " f"to number of quantiles, i.e. {len(self.quantiles)}." ) +======= + def forward(self, input: Tensor, target: Tensor, ldsweight: Union[None, Tensor]) -> Tensor: + assert input.shape == torch.Size( + [target.shape[0], len(self.quantiles)] + ), f"Wrong shape of input, pred_dim of the model that is using QuantileLoss must be equal to number of quantiles, i.e. {len(self.quantiles)}." +>>>>>>> lds added - not tested target = target.view(-1, 1).float() losses = [] for i, q in enumerate(self.quantiles): @@ -65,7 +74,13 @@ class QuantileLoss(nn.Module): losses.append(torch.max((q - 1) * errors, q * errors).unsqueeze(-1)) loss = torch.cat(losses, dim=2) +<<<<<<< HEAD return torch.mean(loss) +======= + if ldsweight is not None: + losses *= ldsweight.expand_as(losses) + return torch.mean(losses) +>>>>>>> lds added - not tested class ZILNLoss(nn.Module): @@ -77,7 +92,7 @@ class ZILNLoss(nn.Module): def __init__(self): super().__init__() - def forward(self, input: Tensor, target: Tensor) -> Tensor: + def forward(self, input: Tensor, target: Tensor, ldsweight: Union[None, Tensor]) -> Tensor: r""" Parameters ---------- @@ -104,6 +119,8 @@ class ZILNLoss(nn.Module): assert input.shape == torch.Size( [target.shape[0], 3] ), "Wrong shape of input, pred_dim of the model that is using ZILNLoss must be equal to 3." + assert ldsweight is not None, "LDS is not implemented for ZILNLoss yet" + positive_input = input[..., :1] classification_loss = F.binary_cross_entropy_with_logits( @@ -211,7 +228,7 @@ class MSLELoss(nn.Module): super().__init__() self.mse = nn.MSELoss() - def forward(self, input: Tensor, target: Tensor) -> Tensor: + def forward(self, input: Tensor, target: Tensor, ldsweight: Union[None, Tensor]) -> Tensor: r""" Parameters ---------- @@ -246,7 +263,7 @@ class RMSELoss(nn.Module): super().__init__() self.mse = nn.MSELoss() - def forward(self, input: Tensor, target: Tensor) -> Tensor: + def forward(self, input: Tensor, target: Tensor, ldsweight: Union[None, Tensor]) -> Tensor: r""" Parameters ---------- @@ -275,7 +292,7 @@ class RMSLELoss(nn.Module): super().__init__() self.mse = nn.MSELoss() - def forward(self, input: Tensor, target: Tensor) -> Tensor: + def forward(self, input: Tensor, target: Tensor, ldsweight: Union[None, Tensor]) -> Tensor: r""" Parameters ---------- @@ -300,4 +317,171 @@ class RMSLELoss(nn.Module): values <0 try to enforce positive values by activation function on last layer with `trainer.enforce_positive_output=True`""" assert target.min() >= 0, "All target values must be >=0" + return torch.sqrt(self.mse(torch.log(input + 1), torch.log(target + 1))) + + +class MSEloss(nn.Module): + r"""Based on + `Yang, Y., Zha, K., Chen, Y. C., Wang, H., & Katabi, D. (2021). + Delving into Deep Imbalanced Regression. arXiv preprint arXiv:2102.09554.` + and their `implementation + ` + """ + + def __init__(self): + super().__init__() + self.mse = nn.MSELoss() + + def forward(self, input: Tensor, target: Tensor, ldsweight: Union[None, Tensor]) -> Tensor: + r""" + Parameters + ---------- + input: Tensor + input tensor with predictions (not probabilities) + target: Tensor + target tensor with the actual classes + + Examples + -------- + """ + loss = (input - target) ** 2 + if ldsweight is not None: + loss *= ldsweight.expand_as(loss) + loss = torch.mean(loss) + + return loss + + +class L1Loss(nn.Module): + r"""Based on + `Yang, Y., Zha, K., Chen, Y. C., Wang, H., & Katabi, D. (2021). + Delving into Deep Imbalanced Regression. arXiv preprint arXiv:2102.09554.` + and their `implementation + ` + """ + + def __init__(self): + super().__init__() + self.mse = nn.MSELoss() + + def forward(self, input: Tensor, target: Tensor, ldsweight: Union[None, Tensor]) -> Tensor: + r""" + Parameters + ---------- + input: Tensor + input tensor with predictions (not probabilities) + target: Tensor + target tensor with the actual classes + + Examples + -------- + """ + loss = F.l1_loss(input, target, reduction='none') + if ldsweight is not None: + loss *= ldsweight.expand_as(loss) + loss = torch.mean(loss) + + return loss + + +class FocalMSELoss(nn.Module): + r"""Based on + `Yang, Y., Zha, K., Chen, Y. C., Wang, H., & Katabi, D. (2021). + Delving into Deep Imbalanced Regression. arXiv preprint arXiv:2102.09554.` + and their `implementation + ` + """ + + def __init__(self): + super().__init__() + self.mse = nn.MSELoss() + + def forward(self, input: Tensor, target: Tensor, ldsweight: Union[None, Tensor], activate='sigmoid', beta=.2, gamma=1) -> Tensor: + r""" + Parameters + ---------- + input: Tensor + input tensor with predictions (not probabilities) + target: Tensor + target tensor with the actual classes + + Examples + -------- + """ + loss = (input - target) ** 2 + loss *= (torch.tanh(beta * torch.abs(input - target))) ** gamma if activate == 'tanh' else \ + (2 * torch.sigmoid(beta * torch.abs(input - target)) - 1) ** gamma + if ldsweight is not None: + loss *= ldsweight.expand_as(loss) + loss = torch.mean(loss) + + return loss + + +class FocalL1Loss(nn.Module): + r"""Based on + `Yang, Y., Zha, K., Chen, Y. C., Wang, H., & Katabi, D. (2021). + Delving into Deep Imbalanced Regression. arXiv preprint arXiv:2102.09554.` + and their `implementation + ` + """ + + def __init__(self): + super().__init__() + self.mse = nn.MSELoss() + + def forward(self, input: Tensor, target: Tensor, ldsweight: Union[None, Tensor], activate='sigmoid', beta=.2, gamma=1) -> Tensor: + r""" + Parameters + ---------- + input: Tensor + input tensor with predictions (not probabilities) + target: Tensor + target tensor with the actual classes + + Examples + -------- + """ + loss = F.l1_loss(input, target, reduction='none') + loss *= (torch.tanh(beta * torch.abs(input - target))) ** gamma if activate == 'tanh' else \ + (2 * torch.sigmoid(beta * torch.abs(input - target)) - 1) ** gamma + if ldsweight is not None: + loss *= ldsweight.expand_as(loss) + loss = torch.mean(loss) + + return loss + + +class HuberLoss(nn.Module): + r"""Based on + `Yang, Y., Zha, K., Chen, Y. C., Wang, H., & Katabi, D. (2021). + Delving into Deep Imbalanced Regression. arXiv preprint arXiv:2102.09554.` + and their `implementation + ` + """ + + def __init__(self): + super().__init__() + self.mse = nn.MSELoss() + + def forward(self, input: Tensor, target: Tensor, ldsweight: Union[None, Tensor], beta=.1) -> Tensor: + r""" + Parameters + ---------- + input: Tensor + input tensor with predictions (not probabilities) + target: Tensor + target tensor with the actual classes + + Examples + -------- + """ + l1_loss = torch.abs(input - target) + cond = l1_loss < beta + loss = torch.where(cond, 0.5 * l1_loss ** 2 / beta, l1_loss - 0.5 * beta) + if ldsweight is not None: + loss *= ldsweight.expand_as(loss) + loss = torch.mean(loss) + + return loss diff --git a/pytorch_widedeep/training/_wd_dataset.py b/pytorch_widedeep/training/_wd_dataset.py index e5f12da..ed15568 100644 --- a/pytorch_widedeep/training/_wd_dataset.py +++ b/pytorch_widedeep/training/_wd_dataset.py @@ -2,9 +2,28 @@ import numpy as np import torch from sklearn.utils import Bunch from torch.utils.data import Dataset +from scipy.ndimage import convolve1d +from scipy.ndimage import gaussian_filter1d +from scipy.signal.windows import triang from pytorch_widedeep.wdtypes import * # noqa: F403 +# TODO assert to limit the usage of LDS only for single value regression objective + +def get_lds_kernel_window(kernel, ks, sigma): + assert kernel in ['gaussian', 'triang', 'laplace'] + half_ks = (ks - 1) // 2 + if kernel == 'gaussian': + base_kernel = [0.] * half_ks + [1.] + [0.] * half_ks + kernel_window = gaussian_filter1d(base_kernel, sigma=sigma) / max(gaussian_filter1d(base_kernel, sigma=sigma)) + elif kernel == 'triang': + kernel_window = triang(ks) + else: + laplace = lambda x: np.exp(-abs(x) / sigma) / (2. * sigma) + kernel_window = list(map(laplace, np.arange(-half_ks, half_ks + 1))) / max(map(laplace, np.arange(-half_ks, half_ks + 1))) + + return kernel_window + class WideDeepDataset(Dataset): r""" @@ -34,6 +53,12 @@ class WideDeepDataset(Dataset): X_img: Optional[np.ndarray] = None, target: Optional[np.ndarray] = None, transforms: Optional[Any] = None, + lds: bool = False, + lds_kernel: str = "gaussian", + lds_ks: int = 5, + lds_sigma: int = 2, + reweight: str = None, + Ymax: Optional[float] = None, ): super(WideDeepDataset, self).__init__() self.X_wide = X_wide @@ -47,7 +72,12 @@ class WideDeepDataset(Dataset): ] else: self.transforms_names = [] + self.weights = self._prepare_weights(reweight=reweight, lds=lds, lds_kernel=lds_kernel, lds_ks=lds_ks, lds_sigma=lds_sigma) self.Y = target + if Ymax is None: + self.Ymax = max(target) + else: + self.Ymax = Ymax def __getitem__(self, idx: int): # noqa: C901 X = Bunch() @@ -86,11 +116,46 @@ class WideDeepDataset(Dataset): # fill the Bunch X.deepimage = xdi if self.Y is not None: + weight = np.asarray([self.weights[idx]]).astype("float32") if self.weights is not None else self.weights y = self.Y[idx] - return X, y + return X, y, weight else: return X + def _prepare_weights(self, reweight, lds=False, lds_kernel='gaussian', lds_ks=5, lds_sigma=2): + assert reweight in {None, "inverse", "sqrt_inv"} + assert reweight != None if lds else True, \ + "Set reweight to \'sqrt_inv\' (default) or \'inverse\' when using LDS" + if reweight is None: + return None + else: + max_target = self.Ymax + value_dict = {x: 0 for x in range(max_target)} + labels = self.Y + # mbr + for label in labels: + value_dict[min(max_target - 1, int(label))] += 1 + if reweight == "sqrt_inv": + value_dict = {k: np.sqrt(v) for k, v in value_dict.items()} + elif reweight == "inverse": + value_dict = {k: np.clip(v, 5, 1000) for k, v in value_dict.items()} # clip weights for inverse re-weight + num_per_label = [value_dict[min(max_target - 1, int(label))] for label in labels] + if not len(num_per_label) or reweight == "none": + return None + print(f"Using re-weighting: [{reweight.upper()}]") + + if lds: + lds_kernel_window = get_lds_kernel_window(lds_kernel, lds_ks, lds_sigma) + print(f"Using LDS: [{lds_kernel.upper()}] ({lds_ks}/{lds_sigma})") + smoothed_value = convolve1d( + np.asarray([v for _, v in value_dict.items()]), weights=lds_kernel_window, mode="constant") + num_per_label = [smoothed_value[min(max_target - 1, int(label))] for label in labels] + + weights = [np.float32(1 / x) for x in num_per_label] + scaling = len(weights) / np.sum(weights) + weights = [scaling * x for x in weights] + return weights + def __len__(self): if self.X_wide is not None: return len(self.X_wide) diff --git a/pytorch_widedeep/training/trainer.py b/pytorch_widedeep/training/trainer.py index 2a3bd56..5a06bde 100644 --- a/pytorch_widedeep/training/trainer.py +++ b/pytorch_widedeep/training/trainer.py @@ -612,9 +612,9 @@ class Trainer: self.train_running_loss = 0.0 with trange(train_steps, disable=self.verbose != 1) as t: - for batch_idx, (data, targett) in zip(t, train_loader): + for batch_idx, (data, targett, weight) in zip(t, train_loader): t.set_description("epoch %i" % (epoch + 1)) - train_score, train_loss = self._train_step(data, targett, batch_idx) + train_score, train_loss = self._train_step(data, targett, batch_idx, weight) print_loss_and_metric(t, train_loss, train_score) self.callback_container.on_batch_end(batch=batch_idx) epoch_logs = save_epoch_logs(epoch_logs, train_loss, train_score, "train") @@ -626,7 +626,7 @@ class Trainer: self.callback_container.on_eval_begin() self.valid_running_loss = 0.0 with trange(eval_steps, disable=self.verbose != 1) as v: - for i, (data, targett) in zip(v, eval_loader): + for i, (data, targett, weight) in zip(v, eval_loader): v.set_description("valid") val_score, val_loss = self._eval_step(data, targett, i) print_loss_and_metric(v, val_loss, val_score) @@ -1150,7 +1150,7 @@ class Trainer: self.model.deepimage, "deepimage", loader, n_epochs, max_lr ) - def _train_step(self, data: Dict[str, Tensor], target: Tensor, batch_idx: int): + def _train_step(self, data: Dict[str, Tensor], target: Tensor, batch_idx: int, weight: Union[None, Tensor]): self.model.train() X = {k: v.cuda() for k, v in data.items()} if use_cuda else data y = ( @@ -1166,7 +1166,7 @@ class Trainer: loss = self.loss_fn(y_pred[0], y) - self.lambda_sparse * y_pred[1] score = self._get_score(y_pred[0], y) else: - loss = self.loss_fn(y_pred, y) + loss = self.loss_fn(y_pred, y, weight=weight) score = self._get_score(y_pred, y) # TODO raise exception if the loss is exploding with non scaled target values loss.backward() @@ -1220,7 +1220,7 @@ class Trainer: self.model.eval() tabnet_backbone = list(self.model.deeptabular.children())[0] feat_imp = np.zeros((tabnet_backbone.embed_and_cont_dim)) # type: ignore[arg-type] - for data, target in loader: + for data, target, weight in loader: X = data["deeptabular"].to(device) y = target.view(-1, 1).float() if self.method != "multiclass" else target y = y.to(device) -- GitLab