Commit cdd674ed authored by Javier Rodriguez Zaurin

Self-supervised runs for 4 attention-based models. The Perceiver will not be supported, and there is no need to test the MLP-Attention models.
Parent 77f99678
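
For context, here is a condensed, hedged sketch of the self-supervised pretraining flow exercised by the updated example script in this commit. The column lists, the with_attention preprocessor flag, and the SAINT hyperparameters are illustrative assumptions (the diff elides the preprocessor construction); only the SelfSupervisedTrainer usage mirrors the script below.

# Minimal sketch (assumptions noted inline); not the exact example script from the commit.
from pytorch_widedeep.models import SAINT
from pytorch_widedeep.datasets import load_adult
from pytorch_widedeep.preprocessing import TabPreprocessor
from pytorch_widedeep.self_supervised_training.self_supervised_trainer import (
    SelfSupervisedTrainer,
)

df = load_adult(as_frame=True)
cat_embed_cols = ["workclass", "education", "occupation"]  # assumed subset of columns
continuous_cols = ["age", "hours-per-week"]  # assumed subset of columns

tab_preprocessor = TabPreprocessor(
    cat_embed_cols=cat_embed_cols,
    continuous_cols=continuous_cols,
    with_attention=True,  # assumption: flag name may differ across library versions
)
X_tab = tab_preprocessor.fit_transform(df)

# Any of the four supported attention-based models could be dropped in here
# (TabTransformer, SAINT, TabFastFormer, FTTransformer); the Perceiver is not supported.
saint = SAINT(
    column_idx=tab_preprocessor.column_idx,
    cat_embed_input=tab_preprocessor.cat_embed_input,
    continuous_cols=continuous_cols,
    n_blocks=4,
)

ss_trainer = SelfSupervisedTrainer(model=saint, preprocessor=tab_preprocessor)
ss_trainer.pretrain(X_tab, n_epochs=1, batch_size=256)
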
@@ -2,7 +2,13 @@ import numpy as np
import torch
import pandas as pd
from pytorch_widedeep.models import TabTransformer
from pytorch_widedeep.models import (
SAINT,
TabPerceiver,
FTTransformer,
TabFastFormer,
TabTransformer,
)
from pytorch_widedeep.datasets import load_adult
from pytorch_widedeep.preprocessing import TabPreprocessor
from pytorch_widedeep.self_supervised_training.self_supervised_trainer import (
@@ -57,19 +63,45 @@ if __name__ == "__main__":
)
X_tab = tab_preprocessor.fit_transform(df)
tab_mlp = TabTransformer(
tab_transformer = TabTransformer(
column_idx=tab_preprocessor.column_idx,
cat_embed_input=tab_preprocessor.cat_embed_input,
continuous_cols=continuous_cols,
embed_continuous=True,
mlp_hidden_dims=[200, 100],
mlp_dropout=0.2,
n_blocks=4,
)
ss_trainer = SelfSupervisedTrainer(
model=tab_mlp,
preprocessor=tab_preprocessor,
cat_mlp_type="single",
cont_mlp_type="single",
saint = SAINT(
column_idx=tab_preprocessor.column_idx,
cat_embed_input=tab_preprocessor.cat_embed_input,
continuous_cols=continuous_cols,
cont_norm_layer="batchnorm",
n_blocks=4,
)
ss_trainer.pretrain(X_tab, n_epochs=1, batch_size=256)
tab_fastformer = TabFastFormer(
column_idx=tab_preprocessor.column_idx,
cat_embed_input=tab_preprocessor.cat_embed_input,
continuous_cols=continuous_cols,
n_blocks=4,
n_heads=4,
share_qv_weights=False,
share_weights=False,
)
ft_transformer = FTTransformer(
column_idx=tab_preprocessor.column_idx,
cat_embed_input=tab_preprocessor.cat_embed_input,
continuous_cols=continuous_cols,
input_dim=32,
kv_compression_factor=0.5,
n_blocks=3,
n_heads=4,
)
for transformer_model in [tab_transformer, saint, tab_fastformer, ft_transformer]:
ss_trainer = SelfSupervisedTrainer(
model=transformer_model,
preprocessor=tab_preprocessor,
)
ss_trainer.pretrain(X_tab, n_epochs=1, batch_size=256)
@@ -14,7 +14,7 @@ class CatSingleMlp(nn.Module):
self.cat_embed_input = cat_embed_input
self.activation = activation
self.num_class = sum([ei[1] for ei in cat_embed_input])
        self.num_class = sum([ei[1] for ei in cat_embed_input if ei[0] != "cls_token"])
self.mlp = MLP(
d_hidden=[input_dim, self.num_class * 4, self.num_class],
@@ -28,11 +28,19 @@ class CatSingleMlp(nn.Module):
def forward(self, X: Tensor, r_: Tensor) -> Tuple[Tensor, Tensor]:
x = torch.cat(
[X[:, self.column_idx[col]].long() for col, _ in self.cat_embed_input]
[
X[:, self.column_idx[col]].long()
for col, _ in self.cat_embed_input
if col != "cls_token"
]
)
cat_r_ = torch.cat(
[r_[:, self.column_idx[col], :] for col, _ in self.cat_embed_input]
[
r_[:, self.column_idx[col], :]
for col, _ in self.cat_embed_input
if col != "cls_token"
]
)
x_ = self.mlp(cat_r_)
@@ -65,16 +73,22 @@ class CatFeaturesMlp(nn.Module):
linear_first=False,
)
for col, val in self.cat_embed_input
if col != "cls_token"
}
)
def forward(self, X: Tensor, r_: Tensor) -> List[Tuple[Tensor, Tensor]]:
x = [X[:, self.column_idx[col]].long() for col, _ in self.cat_embed_input]
x = [
X[:, self.column_idx[col]].long()
for col, _ in self.cat_embed_input
if col != "cls_token"
]
x_ = [
self.mlp["mlp_" + col](r_[:, self.column_idx[col], :])
for col, _ in self.cat_embed_input
if col != "cls_token"
]
return list(zip(x, x_))
@@ -101,11 +115,19 @@ class ContSingleMlp(nn.Module):
def forward(self, X: Tensor, r_: Tensor) -> Tuple[Tensor, Tensor]:
x = torch.cat(
[X[:, self.column_idx[col]].float() for col in self.continuous_cols]
[
X[:, self.column_idx[col]].float()
for col in self.continuous_cols
if col != "cls_token"
]
).unsqueeze(1)
cont_r_ = torch.cat(
[r_[:, self.column_idx[col], :] for col in self.continuous_cols]
[
r_[:, self.column_idx[col], :]
for col in self.continuous_cols
if col != "cls_token"
]
)
x_ = self.mlp(cont_r_)
@@ -138,6 +160,7 @@ class ContFeaturesMlp(nn.Module):
linear_first=False,
)
for col in self.continuous_cols
if col != "cls_token"
}
)
@@ -146,11 +169,13 @@ class ContFeaturesMlp(nn.Module):
x = [
X[:, self.column_idx[col]].unsqueeze(1).float()
for col in self.continuous_cols
if col != "cls_token"
]
x_ = [
self.mlp["mlp_" + col](r_[:, self.column_idx[col]])
for col in self.continuous_cols
if col != "cls_token"
]
return list(zip(x, x_))
@@ -38,9 +38,8 @@ class SelfSupervisedModel(nn.Module):
self.cont_mlp_type = cont_mlp_type
self.denoise_mlps_activation = denoise_mlps_activation
self.cat_embed_input, self.column_idx = self._adjust_if_with_cls_token(
encoding_dict
)
if self.loss_type in ["denoising", "both"]:
self._t = self._tensor_to_subtract(encoding_dict)
self.projection_head1, self.projection_head2 = self._set_projection_heads()
(
@@ -48,8 +47,6 @@ class SelfSupervisedModel(nn.Module):
self.denoise_cont_mlp,
) = self._set_cat_and_cont_denoise_mlps()
self._t = self._tensor_to_subtract(encoding_dict)
def forward(
self, X: Tensor
) -> Tuple[
@@ -58,27 +55,42 @@ class SelfSupervisedModel(nn.Module):
Optional[Tuple[Tensor, Tensor]],
]:
_X = self._prepare_x(X)
# "uncorrupted branch"
embed = self.model._get_embeddings(X)
_embed = embed[:, 1:] if self.model.with_cls_token else embed
encoded = self.model.encoder(_embed)
if self.model.with_cls_token:
embed[:, 0, :] = 0.0
encoded = self.model.encoder(embed)
# cut_mix and mix_up branch
cut_mixed = cut_mix(X)
cut_mixed_embed = self.model._get_embeddings(cut_mixed)
_cut_mixed_embed = (
cut_mixed_embed[:, 1:] if self.model.with_cls_token else cut_mixed_embed
)
cut_mixed_embed_mixed_up = mix_up(_cut_mixed_embed)
if self.model.with_cls_token:
cut_mixed_embed[:, 0, :] = 0.0
cut_mixed_embed_mixed_up = mix_up(cut_mixed_embed)
encoded_ = self.model.encoder(cut_mixed_embed_mixed_up)
        # projections for contrastive loss
if self.loss_type in ["contrastive", "both"]:
g_projs = (self.projection_head1(encoded), self.projection_head2(encoded_))
if self.model.with_cls_token:
g_projs = (
self.projection_head1(encoded[:, 1:, :]),
self.projection_head2(encoded_[:, 1:, :]),
)
else:
g_projs = (
self.projection_head1(encoded),
self.projection_head2(encoded_),
)
else:
g_projs = None
# mlps for denoising loss
if self.loss_type in ["denoising", "both"]:
_X = X - self._t.repeat(X.size(0), 1)
if self.model.with_cls_token:
_X[:, 0] = 0.0
if self.model.cat_embed_input is not None:
cat_x_and_x_ = self.denoise_cat_mlp(_X, encoded_)
else:
@@ -123,15 +135,15 @@ class SelfSupervisedModel(nn.Module):
if self.cat_mlp_type == "single":
denoise_cat_mlp = CatSingleMlp(
self.model.input_dim,
self.cat_embed_input,
self.column_idx,
self.model.cat_embed_input,
self.model.column_idx,
self.denoise_mlps_activation,
)
elif self.cat_mlp_type == "multiple":
denoise_cat_mlp = CatFeaturesMlp(
self.model.input_dim,
self.cat_embed_input,
self.column_idx,
self.model.cat_embed_input,
self.model.column_idx,
self.denoise_mlps_activation,
)
@@ -139,49 +151,24 @@ class SelfSupervisedModel(nn.Module):
denoise_cont_mlp = ContSingleMlp(
self.model.input_dim,
self.model.continuous_cols,
self.column_idx,
self.model.column_idx,
self.denoise_mlps_activation,
)
elif self.cont_mlp_type == "multiple":
denoise_cont_mlp = ContFeaturesMlp(
self.model.input_dim,
self.model.continuous_cols,
self.column_idx,
self.model.column_idx,
self.denoise_mlps_activation,
)
return denoise_cat_mlp, denoise_cont_mlp
def _prepare_x(self, X_tab: Tensor) -> Tensor:
_X_tab = X_tab[:, 1:] if self.model.with_cls_token else X_tab
return _X_tab - self._t.repeat(X_tab.size(0), 1)
def _adjust_if_with_cls_token(self, encoding_dict):
if self.model.with_cls_token:
adj_column_idx = {
k: self.model.column_idx[k] - 1
for k in self.model.column_idx
if k != "cls_token"
}
adj_cat_embed_input = self.model.cat_embed_input[1:]
else:
adj_column_idx = dict(column_idx)
adj_cat_embed_input = self.model.cat_embed_input
return adj_cat_embed_input, adj_column_idx
def _set_idx_to_substract(self, encoding_dict) -> Tensor:
if self.model.with_cls_token:
adj_encoding_dict = {
k: v for k, v in encoding_dict.items() if k != "cls_token"
}
if self.cat_mlp_type == "multiple":
idx_to_substract: Optional[Dict[str, Dict[str, int]]] = {
k: min(sorted(list(v.values()))) for k, v in adj_encoding_dict.items()
k: min(sorted(list(v.values()))) for k, v in encoding_dict.items()
}
if self.cat_mlp_type == "single":
@@ -193,13 +180,15 @@ class SelfSupervisedModel(nn.Module):
self.idx_to_substract = self._set_idx_to_substract(encoding_dict)
_t = torch.zeros(len(self.column_idx))
_t = torch.zeros(len(self.model.column_idx))
if self.idx_to_substract is not None:
for colname, idx in self.idx_to_substract.items():
_t[self.column_idx[colname]] = idx
_t[self.model.column_idx[colname]] = idx
else:
for colname, _ in self.cat_embed_input:
# 0 is reserved for padding, 1 for the '[CLS]' token, if present
_t[self.column_idx[colname]] = 2 if self.model.with_cls_token else 1
_t[self.model.column_idx[colname]] = (
2 if self.model.with_cls_token else 1
)
return _t
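
To make the offset bookkeeping above concrete, here is a small hedged illustration of the "multiple" branch of _set_idx_to_substract / _tensor_to_subtract. The encoding dictionary, column names, and values are invented for the example and are not taken from the commit; the toy example assumes no '[CLS]' token, so category codes start at 1.

import torch

# Hypothetical label encodings (invented): 0 is reserved for padding (and 1 for
# the '[CLS]' token when present), so real category codes start above that.
encoding_dict = {
    "workclass": {"Private": 1, "Self-emp": 2, "Gov": 3},
    "education": {"HS": 4, "Bachelors": 5},
}
column_idx = {"workclass": 0, "education": 1, "age": 2}  # 'age' is continuous

# Mimic the "multiple" case: subtract each categorical column's smallest code so
# the per-feature denoising MLPs receive 0-based class targets; continuous
# columns keep a 0 offset and are left untouched.
_t = torch.zeros(len(column_idx))
for colname, encodings in encoding_dict.items():
    _t[column_idx[colname]] = min(encodings.values())

X = torch.tensor([[2.0, 5.0, 37.0]])  # one row: Self-emp, Bachelors, age 37
print(X - _t.repeat(X.size(0), 1))    # tensor([[ 1.,  1., 37.]])
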