From cdd674ed314c5addad72b71a14c16df1592f9502 Mon Sep 17 00:00:00 2001
From: Javier Rodriguez Zaurin
Date: Tue, 10 May 2022 23:55:15 +0200
Subject: [PATCH] Self-supervised runs for 4 attention-based models. The
 Perceiver will not be supported and there is no need to test the
 MLP-Attention models

---
 .../scripts/adult_census_self_supervised.py   | 52 ++++++++---
 .../self_supervised_training/_denoise_mlps.py | 37 ++++++--
 .../self_supervised_model.py                  | 89 ++++++++-----------
 3 files changed, 112 insertions(+), 66 deletions(-)

diff --git a/examples/scripts/adult_census_self_supervised.py b/examples/scripts/adult_census_self_supervised.py
index aa40887..78c9703 100644
--- a/examples/scripts/adult_census_self_supervised.py
+++ b/examples/scripts/adult_census_self_supervised.py
@@ -2,7 +2,13 @@ import numpy as np
 import torch
 import pandas as pd
-from pytorch_widedeep.models import TabTransformer
+from pytorch_widedeep.models import (
+    SAINT,
+    TabPerceiver,
+    FTTransformer,
+    TabFastFormer,
+    TabTransformer,
+)
 from pytorch_widedeep.datasets import load_adult
 from pytorch_widedeep.preprocessing import TabPreprocessor
 from pytorch_widedeep.self_supervised_training.self_supervised_trainer import (
     SelfSupervisedTrainer,
 )
@@ -57,19 +63,45 @@ if __name__ == "__main__":
     )
     X_tab = tab_preprocessor.fit_transform(df)

-    tab_mlp = TabTransformer(
+    tab_transformer = TabTransformer(
         column_idx=tab_preprocessor.column_idx,
         cat_embed_input=tab_preprocessor.cat_embed_input,
         continuous_cols=continuous_cols,
         embed_continuous=True,
-        mlp_hidden_dims=[200, 100],
-        mlp_dropout=0.2,
+        n_blocks=4,
     )

-    ss_trainer = SelfSupervisedTrainer(
-        model=tab_mlp,
-        preprocessor=tab_preprocessor,
-        cat_mlp_type="single",
-        cont_mlp_type="single",
+    saint = SAINT(
+        column_idx=tab_preprocessor.column_idx,
+        cat_embed_input=tab_preprocessor.cat_embed_input,
+        continuous_cols=continuous_cols,
+        cont_norm_layer="batchnorm",
+        n_blocks=4,
     )
-    ss_trainer.pretrain(X_tab, n_epochs=1, batch_size=256)
+
+    tab_fastformer = TabFastFormer(
+        column_idx=tab_preprocessor.column_idx,
+        cat_embed_input=tab_preprocessor.cat_embed_input,
+        continuous_cols=continuous_cols,
+        n_blocks=4,
+        n_heads=4,
+        share_qv_weights=False,
+        share_weights=False,
+    )
+
+    ft_transformer = FTTransformer(
+        column_idx=tab_preprocessor.column_idx,
+        cat_embed_input=tab_preprocessor.cat_embed_input,
+        continuous_cols=continuous_cols,
+        input_dim=32,
+        kv_compression_factor=0.5,
+        n_blocks=3,
+        n_heads=4,
+    )
+
+    for transformer_model in [tab_transformer, saint, tab_fastformer, ft_transformer]:
+        ss_trainer = SelfSupervisedTrainer(
+            model=transformer_model,
+            preprocessor=tab_preprocessor,
+        )
+        ss_trainer.pretrain(X_tab, n_epochs=1, batch_size=256)
diff --git a/pytorch_widedeep/self_supervised_training/_denoise_mlps.py b/pytorch_widedeep/self_supervised_training/_denoise_mlps.py
index 60ba21d..0771a40 100644
--- a/pytorch_widedeep/self_supervised_training/_denoise_mlps.py
+++ b/pytorch_widedeep/self_supervised_training/_denoise_mlps.py
@@ -14,7 +14,7 @@ class CatSingleMlp(nn.Module):
         self.cat_embed_input = cat_embed_input
         self.activation = activation

-        self.num_class = sum([ei[1] for ei in cat_embed_input])
+        self.num_class = sum([ei[1] for ei in cat_embed_input if ei[0] != "cls_token"])

         self.mlp = MLP(
             d_hidden=[input_dim, self.num_class * 4, self.num_class],
@@ -28,11 +28,19 @@ class CatSingleMlp(nn.Module):
     def forward(self, X: Tensor, r_: Tensor) -> Tuple[Tensor, Tensor]:

         x = torch.cat(
-            [X[:, self.column_idx[col]].long() for col, _ in self.cat_embed_input]
+            [
+                X[:, self.column_idx[col]].long()
+                for col, _ in self.cat_embed_input
+                if col != "cls_token"
+            ]
         )

         cat_r_ = torch.cat(
-            [r_[:, self.column_idx[col], :] for col, _ in self.cat_embed_input]
+            [
+                r_[:, self.column_idx[col], :]
+                for col, _ in self.cat_embed_input
+                if col != "cls_token"
+            ]
         )

         x_ = self.mlp(cat_r_)
@@ -65,16 +73,22 @@
                     linear_first=False,
                 )
                 for col, val in self.cat_embed_input
+                if col != "cls_token"
             }
         )

     def forward(self, X: Tensor, r_: Tensor) -> List[Tuple[Tensor, Tensor]]:

-        x = [X[:, self.column_idx[col]].long() for col, _ in self.cat_embed_input]
+        x = [
+            X[:, self.column_idx[col]].long()
+            for col, _ in self.cat_embed_input
+            if col != "cls_token"
+        ]

         x_ = [
             self.mlp["mlp_" + col](r_[:, self.column_idx[col], :])
             for col, _ in self.cat_embed_input
+            if col != "cls_token"
         ]

         return list(zip(x, x_))
@@ -101,11 +115,19 @@
     def forward(self, X: Tensor, r_: Tensor) -> Tuple[Tensor, Tensor]:

         x = torch.cat(
-            [X[:, self.column_idx[col]].float() for col in self.continuous_cols]
+            [
+                X[:, self.column_idx[col]].float()
+                for col in self.continuous_cols
+                if col != "cls_token"
+            ]
         ).unsqueeze(1)

         cont_r_ = torch.cat(
-            [r_[:, self.column_idx[col], :] for col in self.continuous_cols]
+            [
+                r_[:, self.column_idx[col], :]
+                for col in self.continuous_cols
+                if col != "cls_token"
+            ]
         )

         x_ = self.mlp(cont_r_)
@@ -138,6 +160,7 @@
                     linear_first=False,
                 )
                 for col in self.continuous_cols
+                if col != "cls_token"
             }
         )

@@ -146,11 +169,13 @@
         x = [
             X[:, self.column_idx[col]].unsqueeze(1).float()
             for col in self.continuous_cols
+            if col != "cls_token"
         ]

         x_ = [
             self.mlp["mlp_" + col](r_[:, self.column_idx[col]])
             for col in self.continuous_cols
+            if col != "cls_token"
         ]

         return list(zip(x, x_))
diff --git a/pytorch_widedeep/self_supervised_training/self_supervised_model.py b/pytorch_widedeep/self_supervised_training/self_supervised_model.py
index f25b8c0..db376be 100644
--- a/pytorch_widedeep/self_supervised_training/self_supervised_model.py
+++ b/pytorch_widedeep/self_supervised_training/self_supervised_model.py
@@ -38,9 +38,8 @@ class SelfSupervisedModel(nn.Module):
         self.cont_mlp_type = cont_mlp_type
         self.denoise_mlps_activation = denoise_mlps_activation

-        self.cat_embed_input, self.column_idx = self._adjust_if_with_cls_token(
-            encoding_dict
-        )
+        if self.loss_type in ["denoising", "both"]:
+            self._t = self._tensor_to_subtract(encoding_dict)

         self.projection_head1, self.projection_head2 = self._set_projection_heads()
         (
@@ -48,8 +47,6 @@
             self.denoise_cont_mlp,
         ) = self._set_cat_and_cont_denoise_mlps()

-        self._t = self._tensor_to_subtract(encoding_dict)
-
     def forward(
         self, X: Tensor
     ) -> Tuple[
@@ -58,27 +55,42 @@
         Optional[Tuple[Tensor, Tensor]],
     ]:

-        _X = self._prepare_x(X)
-
+        # "uncorrupted branch"
         embed = self.model._get_embeddings(X)
-        _embed = embed[:, 1:] if self.model.with_cls_token else embed
-
-        encoded = self.model.encoder(_embed)
+        if self.model.with_cls_token:
+            embed[:, 0, :] = 0.0
+        encoded = self.model.encoder(embed)
+
+        # cut_mix and mix_up branch
         cut_mixed = cut_mix(X)
         cut_mixed_embed = self.model._get_embeddings(cut_mixed)
-        _cut_mixed_embed = (
-            cut_mixed_embed[:, 1:] if self.model.with_cls_token else cut_mixed_embed
-        )
-        cut_mixed_embed_mixed_up = mix_up(_cut_mixed_embed)
+        if self.model.with_cls_token:
+            cut_mixed_embed[:, 0, :] = 0.0
+        cut_mixed_embed_mixed_up = mix_up(cut_mixed_embed)
         encoded_ = self.model.encoder(cut_mixed_embed_mixed_up)

+        # projections for contrastive loss
         if self.loss_type in ["contrastive", "both"]:
-            g_projs = (self.projection_head1(encoded), self.projection_head2(encoded_))
+            if self.model.with_cls_token:
+                g_projs = (
+                    self.projection_head1(encoded[:, 1:, :]),
+                    self.projection_head2(encoded_[:, 1:, :]),
+                )
+            else:
+                g_projs = (
+                    self.projection_head1(encoded),
+                    self.projection_head2(encoded_),
+                )
         else:
             g_projs = None

+        # mlps for denoising loss
         if self.loss_type in ["denoising", "both"]:
+
+            _X = X - self._t.repeat(X.size(0), 1)
+            if self.model.with_cls_token:
+                _X[:, 0] = 0.0
+
             if self.model.cat_embed_input is not None:
                 cat_x_and_x_ = self.denoise_cat_mlp(_X, encoded_)
             else:
@@ -123,15 +135,15 @@
         if self.cat_mlp_type == "single":
             denoise_cat_mlp = CatSingleMlp(
                 self.model.input_dim,
-                self.cat_embed_input,
-                self.column_idx,
+                self.model.cat_embed_input,
+                self.model.column_idx,
                 self.denoise_mlps_activation,
             )
         elif self.cat_mlp_type == "multiple":
             denoise_cat_mlp = CatFeaturesMlp(
                 self.model.input_dim,
-                self.cat_embed_input,
-                self.column_idx,
+                self.model.cat_embed_input,
+                self.model.column_idx,
                 self.denoise_mlps_activation,
             )

@@ -139,49 +151,24 @@
             denoise_cont_mlp = ContSingleMlp(
                 self.model.input_dim,
                 self.model.continuous_cols,
-                self.column_idx,
+                self.model.column_idx,
                 self.denoise_mlps_activation,
             )
         elif self.cont_mlp_type == "multiple":
             denoise_cont_mlp = ContFeaturesMlp(
                 self.model.input_dim,
                 self.model.continuous_cols,
-                self.column_idx,
+                self.model.column_idx,
                 self.denoise_mlps_activation,
             )

         return denoise_cat_mlp, denoise_cont_mlp

-    def _prepare_x(self, X_tab: Tensor) -> Tensor:
-
-        _X_tab = X_tab[:, 1:] if self.model.with_cls_token else X_tab
-
-        return _X_tab - self._t.repeat(X_tab.size(0), 1)
-
-    def _adjust_if_with_cls_token(self, encoding_dict):
-        if self.model.with_cls_token:
-            adj_column_idx = {
-                k: self.model.column_idx[k] - 1
-                for k in self.model.column_idx
-                if k != "cls_token"
-            }
-            adj_cat_embed_input = self.model.cat_embed_input[1:]
-        else:
-            adj_column_idx = dict(column_idx)
-            adj_cat_embed_input = self.model.cat_embed_input
-
-        return adj_cat_embed_input, adj_column_idx
-
     def _set_idx_to_substract(self, encoding_dict) -> Tensor:
-        if self.model.with_cls_token:
-            adj_encoding_dict = {
-                k: v for k, v in encoding_dict.items() if k != "cls_token"
-            }
-
         if self.cat_mlp_type == "multiple":
             idx_to_substract: Optional[Dict[str, Dict[str, int]]] = {
-                k: min(sorted(list(v.values()))) for k, v in adj_encoding_dict.items()
+                k: min(sorted(list(v.values()))) for k, v in encoding_dict.items()
             }

         if self.cat_mlp_type == "single":
@@ -193,13 +180,15 @@
         self.idx_to_substract = self._set_idx_to_substract(encoding_dict)

-        _t = torch.zeros(len(self.column_idx))
+        _t = torch.zeros(len(self.model.column_idx))

         if self.idx_to_substract is not None:
             for colname, idx in self.idx_to_substract.items():
-                _t[self.column_idx[colname]] = idx
+                _t[self.model.column_idx[colname]] = idx
         else:
             for colname, _ in self.cat_embed_input:
                 # 0 is reserved for padding, 1 for the '[CLS]' token, if present
-                _t[self.column_idx[colname]] = 2 if self.model.with_cls_token else 1
+                _t[self.model.column_idx[colname]] = (
+                    2 if self.model.with_cls_token else 1
+                )

         return _t
-- 
GitLab
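
Usage note: the loop in the example script only pretrains each attention-based encoder. A minimal sketch of the usual next step, supervised fine-tuning with one of the pretrained encoders, is given below; it assumes the standard pytorch_widedeep WideDeep/Trainer API, and the binary target construction is illustrative rather than part of this patch.

# Hedged sketch: fine-tune a pretrained encoder on the supervised adult census task.
# The target below is a hypothetical construction from the "income" column of the
# DataFrame loaded earlier in the example script.
from pytorch_widedeep import Trainer
from pytorch_widedeep.models import WideDeep

target = df["income"].apply(lambda x: ">50K" in x).astype(int).values

# wrap the (now pretrained) encoder as the deeptabular component of WideDeep
model = WideDeep(deeptabular=ft_transformer)
trainer = Trainer(model, objective="binary")
trainer.fit(X_tab=X_tab, target=target, n_epochs=1, batch_size=256)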