From 77f99678f3ffa9aa372a3134603240cc44a0cbe1 Mon Sep 17 00:00:00 2001
From: Javier Rodriguez Zaurin
Date: Mon, 9 May 2022 23:00:28 +0200
Subject: [PATCH] Self Supervised seems to be running in a small number of
 cases. Now I need to test a few more cases. Adjust unit test and re-organise
 code

---
 .../scripts/adult_census_self_supervised.py   |  12 +-
 pytorch_widedeep/losses.py                    |  47 ++++++--
 .../preprocessing/tab_preprocessor.py         |   8 ++
 .../_base_self_supervised_trainer.py          |   3 +
 .../self_supervised_training/_denoise_mlps.py |  96 +++++++---------
 .../self_supervised_model.py                  | 104 ++++++++++++++++--
 .../self_supervised_trainer.py                |   2 +
 7 files changed, 194 insertions(+), 78 deletions(-)

diff --git a/examples/scripts/adult_census_self_supervised.py b/examples/scripts/adult_census_self_supervised.py
index 822c7bd..aa40887 100644
--- a/examples/scripts/adult_census_self_supervised.py
+++ b/examples/scripts/adult_census_self_supervised.py
@@ -50,7 +50,10 @@ if __name__ == "__main__":
     target = df[target].values
 
     tab_preprocessor = TabPreprocessor(
-        cat_embed_cols=cat_embed_cols, continuous_cols=continuous_cols  # type: ignore[arg-type]
+        cat_embed_cols=cat_embed_cols,
+        continuous_cols=continuous_cols,
+        with_attention=True,
+        with_cls_token=True,
     )
     X_tab = tab_preprocessor.fit_transform(df)
 
@@ -63,5 +66,10 @@ if __name__ == "__main__":
         mlp_dropout=0.2,
     )
 
-    ss_trainer = SelfSupervisedTrainer(tab_mlp)
+    ss_trainer = SelfSupervisedTrainer(
+        model=tab_mlp,
+        preprocessor=tab_preprocessor,
+        cat_mlp_type="single",
+        cont_mlp_type="single",
+    )
     ss_trainer.pretrain(X_tab, n_epochs=1, batch_size=256)

diff --git a/pytorch_widedeep/losses.py b/pytorch_widedeep/losses.py
index fb8ac16..08fa810 100644
--- a/pytorch_widedeep/losses.py
+++ b/pytorch_widedeep/losses.py
@@ -857,22 +857,47 @@ class DenoisingLoss(nn.Module):
 
     def forward(
         self,
-        x_cat_and_cat_: Optional[Tuple[Tensor, Tensor]],
-        x_cont_and_cont_: Optional[Tuple[Tensor, Tensor]],
+        x_cat_and_cat_: Optional[
+            Union[List[Tuple[Tensor, Tensor]], Tuple[Tensor, Tensor]]
+        ],
+        x_cont_and_cont_: Optional[
+            Union[List[Tuple[Tensor, Tensor]], Tuple[Tensor, Tensor]]
+        ],
     ) -> Tensor:
 
-        if x_cat_and_cat_ is not None:
-            loss_cat = torch.tensor(0.0)
+        loss_cat = (
+            self._compute_cat_loss(x_cat_and_cat_)
+            if x_cat_and_cat_ is not None
+            else torch.tensor(0.0)
+        )
+        loss_cont = (
+            self._compute_cont_loss(x_cont_and_cont_)
+            if x_cont_and_cont_ is not None
+            else torch.tensor(0.0)
+        )
+
+        return self.lambda_cat * loss_cat + self.lambda_cont * loss_cont
+
+    def _compute_cat_loss(self, x_cat_and_cat_):
+
+        loss_cat = torch.tensor(0.0)
+        if isinstance(x_cat_and_cat_, list):
             for x, x_ in x_cat_and_cat_:
                 loss_cat += F.cross_entropy(x_, x, reduction=self.reduction)
-        else:
-            loss_cat = torch.tensor(0.0)
+        elif isinstance(x_cat_and_cat_, tuple):
+            x, x_ = x_cat_and_cat_
+            loss_cat += F.cross_entropy(x_, x, reduction=self.reduction)
+
+        return loss_cat
+
+    def _compute_cont_loss(self, x_cont_and_cont_):
 
-        if x_cont_and_cont_ is not None:
-            loss_cont = torch.tensor(0.0)
+        loss_cont = torch.tensor(0.0)
+        if isinstance(x_cont_and_cont_, list):
             for x, x_ in x_cont_and_cont_:
                 loss_cont += F.mse_loss(x_, x, reduction=self.reduction)
-        else:
-            loss_cat = torch.tensor(0.0)
+        elif isinstance(x_cont_and_cont_, tuple):
+            x, x_ = x_cont_and_cont_
+            loss_cont += F.mse_loss(x_, x, reduction=self.reduction)
 
-        return self.lambda_cat * loss_cat + self.lambda_cont * loss_cont
+        return loss_cont
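Note: with this change DenoisingLoss.forward accepts either a single (target, prediction) tuple, which is what the "single" denoising MLPs return, or a list of per-column tuples, which is what the "multiple" MLPs return. A minimal sketch of both call patterns; the constructor keyword names are inferred from the attributes used in forward and the tensor shapes are illustrative only:

    import torch
    from pytorch_widedeep.losses import DenoisingLoss

    # keyword names assumed from self.lambda_cat / self.lambda_cont / self.reduction above
    loss_fn = DenoisingLoss(lambda_cat=1.0, lambda_cont=1.0, reduction="mean")

    cat_true = torch.randint(0, 5, (32,))   # category indices, already starting at 0
    cat_pred = torch.randn(32, 5)           # reconstructed logits
    cont_true = torch.randn(32, 1)
    cont_pred = torch.randn(32, 1)

    # "single" MLPs: one tuple per data type
    loss = loss_fn((cat_true, cat_pred), (cont_true, cont_pred))

    # "multiple" MLPs: a list with one tuple per column
    loss = loss_fn([(cat_true, cat_pred)], [(cont_true, cont_pred)])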
diff --git a/pytorch_widedeep/preprocessing/tab_preprocessor.py b/pytorch_widedeep/preprocessing/tab_preprocessor.py
index d5f061b..f67ccd9 100644
--- a/pytorch_widedeep/preprocessing/tab_preprocessor.py
+++ b/pytorch_widedeep/preprocessing/tab_preprocessor.py
@@ -292,6 +292,14 @@ class TabPreprocessor(BasePreprocessor):
         return df.copy()[self.continuous_cols]
 
     def _check_inputs(self):
+
+        if self.with_cls_token and not self.with_attention:
+            warnings.warn(
+                "If 'with_cls_token' is set to 'True', 'with_attention' will be "
+                "automatically set to 'True'",
+            )
+            self.with_attention = True
+
         if (self.cat_embed_cols is None) and (self.continuous_cols is None):
             raise ValueError(
                 "'cat_embed_cols' and 'continuous_cols' are 'None'. Please, define at least one of the two."
diff --git a/pytorch_widedeep/self_supervised_training/_base_self_supervised_trainer.py b/pytorch_widedeep/self_supervised_training/_base_self_supervised_trainer.py
index 7ff8edc..2922011 100644
--- a/pytorch_widedeep/self_supervised_training/_base_self_supervised_trainer.py
+++ b/pytorch_widedeep/self_supervised_training/_base_self_supervised_trainer.py
@@ -12,6 +12,7 @@ from pytorch_widedeep.callbacks import (
     CallbackContainer,
     LRShedulerCallback,
 )
+from pytorch_widedeep.preprocessing.tab_preprocessor import TabPreprocessor
 from pytorch_widedeep.self_supervised_training.self_supervised_model import (
     SelfSupervisedModel,
 )
@@ -21,6 +22,7 @@ class BaseSelfSupervisedTrainer(ABC):
     def __init__(
         self,
         model,
+        preprocessor: TabPreprocessor,
         optimizer: Optional[Optimizer],
         lr_scheduler: Optional[LRScheduler],
         callbacks: Optional[List[Callback]],
@@ -38,6 +40,7 @@ class BaseSelfSupervisedTrainer(ABC):
 
         self.ss_model = SelfSupervisedModel(
             model,
+            preprocessor.label_encoder.encoding_dict,
             loss_type,
             projection_head1_dims,
             projection_head2_dims,
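Note: the base trainer now requires the fitted TabPreprocessor so that the per-column label encodings (preprocessor.label_encoder.encoding_dict) can be handed to SelfSupervisedModel. A runnable sketch of the preprocessing side of that contract; the DataFrame and column names below are made up for illustration:

    import pandas as pd
    from pytorch_widedeep.preprocessing import TabPreprocessor

    df = pd.DataFrame(
        {"col_a": ["low", "high", "low"], "col_b": ["a", "b", "c"], "age": [23, 41, 35]}
    )
    tab_preprocessor = TabPreprocessor(
        cat_embed_cols=["col_a", "col_b"],
        continuous_cols=["age"],
        with_attention=True,   # with_cls_token=True on its own now warns and switches this on
        with_cls_token=True,
    )
    X_tab = tab_preprocessor.fit_transform(df)

    # what the trainer forwards to SelfSupervisedModel: {column: {category: encoded index}}
    encoding_dict = tab_preprocessor.label_encoder.encoding_dict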
diff --git a/pytorch_widedeep/self_supervised_training/_denoise_mlps.py b/pytorch_widedeep/self_supervised_training/_denoise_mlps.py
index af887cc..60ba21d 100644
--- a/pytorch_widedeep/self_supervised_training/_denoise_mlps.py
+++ b/pytorch_widedeep/self_supervised_training/_denoise_mlps.py
@@ -6,23 +6,18 @@ from pytorch_widedeep.models.tabular.mlp._layers import MLP
 
 
 class CatSingleMlp(nn.Module):
-    def __init__(self, model, activation):
+    def __init__(self, input_dim, cat_embed_input, column_idx, activation):
         super(CatSingleMlp, self).__init__()
 
-        self.column_idx = model.column_idx
-        self.cat_embed_input = model.cat_embed_input
-        self.num_class = model.cat_and_cont_embed.cat_embed.n_tokens
-
+        self.input_dim = input_dim
+        self.column_idx = column_idx
+        self.cat_embed_input = cat_embed_input
         self.activation = activation
 
-        mlp_hidden_dims = [
-            model.input_dim,
-            self.num_class * 4,
-            self.num_class,
-        ]
+        self.num_class = sum([ei[1] for ei in cat_embed_input])
 
         self.mlp = MLP(
-            d_hidden=mlp_hidden_dims,
+            d_hidden=[input_dim, self.num_class * 4, self.num_class],
             activation=activation,
             dropout=0.0,
             batchnorm=False,
@@ -32,30 +27,26 @@ class CatSingleMlp(nn.Module):
 
     def forward(self, X: Tensor, r_: Tensor) -> Tuple[Tensor, Tensor]:
 
-        # '-1' because the Label Encoder is designed to leave 0 for unseen
-        # categories (or padding) during the supervised training. Here we
-        # are predicting directly the categories, and a categorical target
-        # has to start from 0
-        x = torch.stack(
-            [
-                X[:, self.column_idx[col]].long() - 1
-                for col, _, _ in self.cat_embed_input
-            ],
-            dim=1,
+        x = torch.cat(
+            [X[:, self.column_idx[col]].long() for col, _ in self.cat_embed_input]
         )
 
-        x_ = self.mlp(r_)
+        cat_r_ = torch.cat(
+            [r_[:, self.column_idx[col], :] for col, _ in self.cat_embed_input]
+        )
+
+        x_ = self.mlp(cat_r_)
 
         return x, x_
 
 
 class CatFeaturesMlp(nn.Module):
-    def __init__(self, model, activation):
+    def __init__(self, input_dim, cat_embed_input, column_idx, activation):
         super(CatFeaturesMlp, self).__init__()
 
-        self.column_idx = model.column_idx
-        self.cat_embed_input = model.cat_embed_input
-
+        self.input_dim = input_dim
+        self.column_idx = column_idx
+        self.cat_embed_input = cat_embed_input
         self.activation = activation
 
         self.mlp = nn.ModuleDict(
@@ -63,7 +54,7 @@ class CatFeaturesMlp(nn.Module):
                 "mlp_"
                 + col: MLP(
                     d_hidden=[
-                        model.input_dim,
+                        input_dim,
                         val * 4,
                         val,
                     ],
@@ -73,43 +64,33 @@ class CatFeaturesMlp(nn.Module):
                     batchnorm_last=False,
                     linear_first=False,
                 )
-                for col, val, _ in model.cat_embed_input
+                for col, val in self.cat_embed_input
             }
         )
 
     def forward(self, X: Tensor, r_: Tensor) -> List[Tuple[Tensor, Tensor]]:
 
-        # '-1' because the Label Encoder is designed to leave 0 for unseen
-        # categories (or padding) during the supervised training. Here we
-        # are predicting directly the categories, and a categorical target
-        # has to start from 0
-        x = [
-            X[:, self.column_idx[col]].long() - 1 for col, _, _ in self.cat_embed_input
-        ]
+        x = [X[:, self.column_idx[col]].long() for col, _ in self.cat_embed_input]
 
         x_ = [
             self.mlp["mlp_" + col](r_[:, self.column_idx[col], :])
-            for col, _, _ in self.cat_embed_input
+            for col, _ in self.cat_embed_input
         ]
 
         return list(zip(x, x_))
 
 
 class ContSingleMlp(nn.Module):
-    def __init__(self, model, activation):
+    def __init__(self, input_dim, continuous_cols, column_idx, activation):
         super(ContSingleMlp, self).__init__()
 
-        self.column_idx = model.column_idx
-        self.continuous_cols = model.continuous_cols
-
+        self.input_dim = input_dim
+        self.column_idx = column_idx
+        self.continuous_cols = continuous_cols
         self.activation = activation
 
         self.mlp = MLP(
-            d_hidden=[
-                model.input_dim,
-                model.input_dim * 2,
-                1,
-            ],
+            d_hidden=[input_dim, input_dim * 2, 1],
             activation=activation,
             dropout=0.0,
             batchnorm=False,
@@ -119,23 +100,26 @@ class ContSingleMlp(nn.Module):
 
     def forward(self, X: Tensor, r_: Tensor) -> Tuple[Tensor, Tensor]:
 
-        x = torch.stack(
-            [X[:, self.column_idx[col]] for col in self.continuous_cols],
-            dim=1,
+        x = torch.cat(
+            [X[:, self.column_idx[col]].float() for col in self.continuous_cols]
+        ).unsqueeze(1)
+
+        cont_r_ = torch.cat(
+            [r_[:, self.column_idx[col], :] for col in self.continuous_cols]
        )
 
-        x_ = self.mlp(r_)
+        x_ = self.mlp(cont_r_)
 
         return x, x_
 
 
 class ContFeaturesMlp(nn.Module):
-    def __init__(self, model, activation):
+    def __init__(self, input_dim, continuous_cols, column_idx, activation):
         super(ContFeaturesMlp, self).__init__()
 
-        self.column_idx = model.column_idx
-        self.continuous_cols = model.continuous_cols
-
+        self.input_dim = input_dim
+        self.column_idx = column_idx
+        self.continuous_cols = continuous_cols
         self.activation = activation
 
         self.mlp = nn.ModuleDict(
@@ -143,8 +127,8 @@ class ContFeaturesMlp(nn.Module):
                 "mlp_"
                 + col: MLP(
                     d_hidden=[
-                        model.input_dim,
-                        model.input_dim * 2,
+                        input_dim,
+                        input_dim * 2,
                         1,
                     ],
                     activation=activation,
@@ -153,7 +137,7 @@ class ContFeaturesMlp(nn.Module):
                     batchnorm_last=False,
                     linear_first=False,
                 )
-                for col in model.continuous_cols
+                for col in self.continuous_cols
             }
         )
 
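Note: the denoising heads no longer receive the whole model; they take input_dim, the (already adjusted) cat_embed_input or continuous_cols, and column_idx directly. In the "single" variants every column is concatenated along the batch dimension, so with batch size B and C categorical columns the targets have shape (B*C,) and the predictions (B*C, num_class). A toy shape check; the module path is the one in the diff above and all values are made up:

    import torch

    from pytorch_widedeep.self_supervised_training._denoise_mlps import CatSingleMlp

    column_idx = {"col_a": 0, "col_b": 1, "age": 2}    # toy column layout
    cat_embed_input = [("col_a", 4), ("col_b", 7)]     # (column name, number of categories)

    mlp = CatSingleMlp(32, cat_embed_input, column_idx, "relu")

    X = torch.randint(0, 4, (8, 3)).float()  # offset-adjusted input, (B, n_features)
    r_ = torch.randn(8, 3, 32)               # encoder output, (B, n_features, input_dim)

    x, x_ = mlp(X, r_)                       # x: (16,), x_: (16, 11) -> F.cross_entropy(x_, x)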
diff --git a/pytorch_widedeep/self_supervised_training/self_supervised_model.py b/pytorch_widedeep/self_supervised_training/self_supervised_model.py
index 8fe5dbc..f25b8c0 100644
--- a/pytorch_widedeep/self_supervised_training/self_supervised_model.py
+++ b/pytorch_widedeep/self_supervised_training/self_supervised_model.py
@@ -18,6 +18,7 @@ class SelfSupervisedModel(nn.Module):
     def __init__(
         self,
         model: nn.Module,
+        encoding_dict: Dict[str, Dict[str, int]],
         loss_type: Literal["contrastive", "denoising", "both"],
         projection_head1_dims: Optional[List],
         projection_head2_dims: Optional[List],
@@ -29,7 +30,6 @@ class SelfSupervisedModel(nn.Module):
         super(SelfSupervisedModel, self).__init__()
 
         self.model = model
-
         self.loss_type = loss_type
         self.projection_head1_dims = projection_head1_dims
         self.projection_head2_dims = projection_head2_dims
@@ -38,12 +38,18 @@ class SelfSupervisedModel(nn.Module):
         self.cont_mlp_type = cont_mlp_type
         self.denoise_mlps_activation = denoise_mlps_activation
 
+        self.cat_embed_input, self.column_idx = self._adjust_if_with_cls_token(
+            encoding_dict
+        )
+
         self.projection_head1, self.projection_head2 = self._set_projection_heads()
         (
             self.denoise_cat_mlp,
             self.denoise_cont_mlp,
         ) = self._set_cat_and_cont_denoise_mlps()
 
+        self._t = self._tensor_to_subtract(encoding_dict)
+
     def forward(
         self, X: Tensor
     ) -> Tuple[
@@ -52,11 +58,19 @@ class SelfSupervisedModel(nn.Module):
         Optional[Tuple[Tensor, Tensor]],
     ]:
 
-        encoded = self.model.encoder(self.model._get_embeddings(X))
+        _X = self._prepare_x(X)
+
+        embed = self.model._get_embeddings(X)
+        _embed = embed[:, 1:] if self.model.with_cls_token else embed
+
+        encoded = self.model.encoder(_embed)
 
         cut_mixed = cut_mix(X)
         cut_mixed_embed = self.model._get_embeddings(cut_mixed)
-        cut_mixed_embed_mixed_up = mix_up(cut_mixed_embed)
+        _cut_mixed_embed = (
+            cut_mixed_embed[:, 1:] if self.model.with_cls_token else cut_mixed_embed
+        )
+        cut_mixed_embed_mixed_up = mix_up(_cut_mixed_embed)
         encoded_ = self.model.encoder(cut_mixed_embed_mixed_up)
 
         if self.loss_type in ["contrastive", "both"]:
@@ -66,11 +80,11 @@ class SelfSupervisedModel(nn.Module):
 
         if self.loss_type in ["denoising", "both"]:
             if self.model.cat_embed_input is not None:
-                cat_x_and_x_ = self.denoise_cat_mlp(X, encoded_)
+                cat_x_and_x_ = self.denoise_cat_mlp(_X, encoded_)
             else:
                 cat_x_and_x_ = None
             if self.model.continuous_cols is not None:
-                cont_x_and_x_ = self.denoise_cont_mlp(X, encoded_)
+                cont_x_and_x_ = self.denoise_cont_mlp(_X, encoded_)
             else:
                 cont_x_and_x_ = None
 
@@ -107,13 +121,85 @@ class SelfSupervisedModel(nn.Module):
     def _set_cat_and_cont_denoise_mlps(self) -> Tuple[nn.Module, nn.Module]:
 
         if self.cat_mlp_type == "single":
-            denoise_cat_mlp = CatSingleMlp(self.model, self.denoise_mlps_activation)
+            denoise_cat_mlp = CatSingleMlp(
+                self.model.input_dim,
+                self.cat_embed_input,
+                self.column_idx,
+                self.denoise_mlps_activation,
+            )
         elif self.cat_mlp_type == "multiple":
-            denoise_cat_mlp = CatFeaturesMlp(self.model, self.denoise_mlps_activation)
+            denoise_cat_mlp = CatFeaturesMlp(
+                self.model.input_dim,
+                self.cat_embed_input,
+                self.column_idx,
+                self.denoise_mlps_activation,
+            )
 
         if self.cont_mlp_type == "single":
-            denoise_cont_mlp = ContSingleMlp(self.model, self.denoise_mlps_activation)
+            denoise_cont_mlp = ContSingleMlp(
+                self.model.input_dim,
+                self.model.continuous_cols,
+                self.column_idx,
+                self.denoise_mlps_activation,
+            )
         elif self.cont_mlp_type == "multiple":
-            denoise_cont_mlp = ContFeaturesMlp(self.model, self.denoise_mlps_activation)
+            denoise_cont_mlp = ContFeaturesMlp(
+                self.model.input_dim,
+                self.model.continuous_cols,
+                self.column_idx,
+                self.denoise_mlps_activation,
+            )
 
         return denoise_cat_mlp, denoise_cont_mlp
+
+    def _prepare_x(self, X_tab: Tensor) -> Tensor:
+
+        _X_tab = X_tab[:, 1:] if self.model.with_cls_token else X_tab
+
+        return _X_tab - self._t.repeat(X_tab.size(0), 1)
+
+    def _adjust_if_with_cls_token(self, encoding_dict):
+        if self.model.with_cls_token:
+            adj_column_idx = {
+                k: self.model.column_idx[k] - 1
+                for k in self.model.column_idx
+                if k != "cls_token"
+            }
+            adj_cat_embed_input = self.model.cat_embed_input[1:]
+        else:
+            adj_column_idx = dict(self.model.column_idx)
+            adj_cat_embed_input = self.model.cat_embed_input
+
+        return adj_cat_embed_input, adj_column_idx
+
+    def _set_idx_to_substract(self, encoding_dict) -> Optional[Dict[str, int]]:
+
+        if self.model.with_cls_token:
+            adj_encoding_dict = {
+                k: v for k, v in encoding_dict.items() if k != "cls_token"
+            }
+        else:
+            adj_encoding_dict = encoding_dict
+
+        if self.cat_mlp_type == "multiple":
+            idx_to_substract: Optional[Dict[str, int]] = {
+                k: min(sorted(list(v.values()))) for k, v in adj_encoding_dict.items()
+            }
+
+        if self.cat_mlp_type == "single":
+            idx_to_substract = None
+
+        return idx_to_substract
+
+    def _tensor_to_subtract(self, encoding_dict) -> Tensor:
+
+        self.idx_to_substract = self._set_idx_to_substract(encoding_dict)
+
+        _t = torch.zeros(len(self.column_idx))
+
+        if self.idx_to_substract is not None:
+            for colname, idx in self.idx_to_substract.items():
+                _t[self.column_idx[colname]] = idx
+        else:
+            for colname, _ in self.cat_embed_input:
+                # 0 is reserved for padding, 1 for the '[CLS]' token, if present
+                _t[self.column_idx[colname]] = 2 if self.model.with_cls_token else 1
+        return _t
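Note: _tensor_to_subtract replaces the old "-1" trick inside the denoise MLPs: the label-encoded input is shifted so that every reconstruction target starts at 0. With cat_mlp_type="multiple" the per-column minimum of the encoding dict is subtracted; with "single" a constant offset is used (1, or 2 when index 1 is taken by the '[CLS]' token). A small worked example of the "multiple" case with a made-up encoding dict:

    encoding_dict = {
        "col_a": {"low": 1, "high": 2},
        "col_b": {"a": 3, "b": 4, "c": 5},
    }
    idx_to_subtract = {k: min(v.values()) for k, v in encoding_dict.items()}
    # {'col_a': 1, 'col_b': 3}: X[:, col_a] - 1 and X[:, col_b] - 3 both start at 0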
diff --git a/pytorch_widedeep/self_supervised_training/self_supervised_trainer.py b/pytorch_widedeep/self_supervised_training/self_supervised_trainer.py
index 1171d38..4beff63 100644
--- a/pytorch_widedeep/self_supervised_training/self_supervised_trainer.py
+++ b/pytorch_widedeep/self_supervised_training/self_supervised_trainer.py
@@ -18,6 +18,7 @@ class SelfSupervisedTrainer(BaseSelfSupervisedTrainer):
     def __init__(
         self,
         model,
+        preprocessor,
         optimizer: Optional[Optimizer] = None,
         lr_scheduler: Optional[LRScheduler] = None,
         callbacks: Optional[List[Callback]] = None,
@@ -34,6 +35,7 @@ class SelfSupervisedTrainer(BaseSelfSupervisedTrainer):
     ):
         super().__init__(
             model=model,
+            preprocessor=preprocessor,
             loss_type=loss_type,
             optimizer=optimizer,
             lr_scheduler=lr_scheduler,
-- 
GitLab
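Taken together, the pieces above are exercised by the updated example script: the fitted preprocessor is now passed to the trainer next to the model. A condensed end-to-end sketch; df, cat_embed_cols and continuous_cols are placeholders, and the model construction is elided because it sits outside the hunks shown (any model exposing the attributes used by SelfSupervisedModel, such as encoder, _get_embeddings, with_cls_token, column_idx and input_dim, fits here):

    tab_preprocessor = TabPreprocessor(
        cat_embed_cols=cat_embed_cols,
        continuous_cols=continuous_cols,
        with_attention=True,
        with_cls_token=True,
    )
    X_tab = tab_preprocessor.fit_transform(df)

    model = ...  # attention-based tabular model built from tab_preprocessor's column_idx / cat_embed_input

    ss_trainer = SelfSupervisedTrainer(
        model=model,
        preprocessor=tab_preprocessor,
        cat_mlp_type="single",
        cont_mlp_type="single",
    )
    ss_trainer.pretrain(X_tab, n_epochs=1, batch_size=256)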