Commit 77f99678 authored by Javier Rodriguez Zaurin

Self-supervised training seems to be running in a small number of cases. Now I need to test a few more. Adjust the unit tests and re-organise the code.
Parent ad68b1b2
......@@ -50,7 +50,10 @@ if __name__ == "__main__":
target = df[target].values
tab_preprocessor = TabPreprocessor(
cat_embed_cols=cat_embed_cols, continuous_cols=continuous_cols # type: ignore[arg-type]
cat_embed_cols=cat_embed_cols,
continuous_cols=continuous_cols,
with_attention=True,
with_cls_token=True,
)
X_tab = tab_preprocessor.fit_transform(df)
......@@ -63,5 +66,10 @@ if __name__ == "__main__":
mlp_dropout=0.2,
)
ss_trainer = SelfSupervisedTrainer(tab_mlp)
ss_trainer = SelfSupervisedTrainer(
model=tab_mlp,
preprocessor=tab_preprocessor,
cat_mlp_type="single",
cont_mlp_type="single",
)
ss_trainer.pretrain(X_tab, n_epochs=1, batch_size=256)
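The new preprocessor argument matters because, as the changes to BaseSelfSupervisedTrainer further down show, the trainer forwards the preprocessor's label-encoder mapping into SelfSupervisedModel so the denoising heads can build per-column category targets. A minimal sketch of what is read from it (attribute path as it appears in this diff):

encoding_dict = tab_preprocessor.label_encoder.encoding_dict  # {column: {category: int}}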
......@@ -857,22 +857,47 @@ class DenoisingLoss(nn.Module):
def forward(
self,
x_cat_and_cat_: Optional[Tuple[Tensor, Tensor]],
x_cont_and_cont_: Optional[Tuple[Tensor, Tensor]],
x_cat_and_cat_: Optional[
Union[List[Tuple[Tensor, Tensor]], Tuple[Tensor, Tensor]]
],
x_cont_and_cont_: Optional[
Union[List[Tuple[Tensor, Tensor]], Tuple[Tensor, Tensor]]
],
) -> Tensor:
    loss_cat = (
        self._compute_cat_loss(x_cat_and_cat_)
        if x_cat_and_cat_ is not None
        else torch.tensor(0.0)
    )
    loss_cont = (
        self._compute_cont_loss(x_cont_and_cont_)
        if x_cont_and_cont_ is not None
        else torch.tensor(0.0)
    )
    return self.lambda_cat * loss_cat + self.lambda_cont * loss_cont

def _compute_cat_loss(self, x_cat_and_cat_):
    loss_cat = torch.tensor(0.0)
    if isinstance(x_cat_and_cat_, list):
        for x, x_ in x_cat_and_cat_:
            loss_cat += F.cross_entropy(x_, x, reduction=self.reduction)
    elif isinstance(x_cat_and_cat_, tuple):
        x, x_ = x_cat_and_cat_
        loss_cat += F.cross_entropy(x_, x, reduction=self.reduction)
    return loss_cat

def _compute_cont_loss(self, x_cont_and_cont_):
    loss_cont = torch.tensor(0.0)
    if isinstance(x_cont_and_cont_, list):
        for x, x_ in x_cont_and_cont_:
            loss_cont += F.mse_loss(x_, x, reduction=self.reduction)
    elif isinstance(x_cont_and_cont_, tuple):
        x, x_ = x_cont_and_cont_
        loss_cont += F.mse_loss(x_, x, reduction=self.reduction)
    return loss_cont
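With cat_mlp_type="multiple" the denoising MLPs return one (target, prediction) pair per column, so the loss now accepts either a list of tuples or a single tuple. A small standalone sketch of the two shapes (illustrative tensors only, not part of this commit):

import torch
import torch.nn.functional as F

batch = 4
# "multiple": one (targets, logits) pair per categorical column
cat_pairs = [
    (torch.randint(0, 3, (batch,)), torch.randn(batch, 3)),  # column with 3 categories
    (torch.randint(0, 5, (batch,)), torch.randn(batch, 5)),  # column with 5 categories
]
loss_cat = sum(F.cross_entropy(x_, x) for x, x_ in cat_pairs)

# "single": all continuous targets/reconstructions concatenated into one pair
cont_pair = (torch.randn(batch, 1), torch.randn(batch, 1))
loss_cont = F.mse_loss(cont_pair[1], cont_pair[0])
print(loss_cat.item(), loss_cont.item())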
......@@ -292,6 +292,14 @@ class TabPreprocessor(BasePreprocessor):
return df.copy()[self.continuous_cols]
def _check_inputs(self):
if self.with_cls_token and not self.with_attention:
warnings.warn(
"If 'with_cls_token' is set to True, 'with_attention' will automatically "
"be set to True."
)
self.with_attention = True
if (self.cat_embed_cols is None) and (self.continuous_cols is None):
raise ValueError(
"'cat_embed_cols' and 'continuous_cols' are 'None'. Please, define at least one of the two."
......
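The new check lets a user request a '[CLS]' token without remembering to enable attention: the preprocessor warns and flips the flag itself. Roughly (hypothetical column names, and assuming _check_inputs runs at construction time like the other input checks):

tp = TabPreprocessor(cat_embed_cols=["education", "occupation"], with_cls_token=True)
# a warning is emitted and tp.with_attention is now True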
......@@ -12,6 +12,7 @@ from pytorch_widedeep.callbacks import (
CallbackContainer,
LRShedulerCallback,
)
from pytorch_widedeep.preprocessing.tab_preprocessor import TabPreprocessor
from pytorch_widedeep.self_supervised_training.self_supervised_model import (
SelfSupervisedModel,
)
......@@ -21,6 +22,7 @@ class BaseSelfSupervisedTrainer(ABC):
def __init__(
self,
model,
preprocessor: TabPreprocessor,
optimizer: Optional[Optimizer],
lr_scheduler: Optional[LRScheduler],
callbacks: Optional[List[Callback]],
......@@ -38,6 +40,7 @@ class BaseSelfSupervisedTrainer(ABC):
self.ss_model = SelfSupervisedModel(
model,
preprocessor.label_encoder.encoding_dict,
loss_type,
projection_head1_dims,
projection_head2_dims,
......
......@@ -6,23 +6,18 @@ from pytorch_widedeep.models.tabular.mlp._layers import MLP
class CatSingleMlp(nn.Module):
def __init__(self, model, activation):
def __init__(self, input_dim, cat_embed_input, column_idx, activation):
super(CatSingleMlp, self).__init__()
self.column_idx = model.column_idx
self.cat_embed_input = model.cat_embed_input
self.num_class = model.cat_and_cont_embed.cat_embed.n_tokens
self.input_dim = input_dim
self.column_idx = column_idx
self.cat_embed_input = cat_embed_input
self.activation = activation
mlp_hidden_dims = [
model.input_dim,
self.num_class * 4,
self.num_class,
]
self.num_class = sum([ei[1] for ei in cat_embed_input])
self.mlp = MLP(
d_hidden=mlp_hidden_dims,
d_hidden=[input_dim, self.num_class * 4, self.num_class],
activation=activation,
dropout=0.0,
batchnorm=False,
......@@ -32,30 +27,26 @@ class CatSingleMlp(nn.Module):
def forward(self, X: Tensor, r_: Tensor) -> Tuple[Tensor, Tensor]:
# '-1' because the Label Encoder is designed to leave 0 for unseen
# categories (or padding) during the supervised training. Here we
# are predicting directly the categories, and a categorical target
# has to start from 0
x = torch.stack(
[
X[:, self.column_idx[col]].long() - 1
for col, _, _ in self.cat_embed_input
],
dim=1,
x = torch.cat(
[X[:, self.column_idx[col]].long() for col, _ in self.cat_embed_input]
)
x_ = self.mlp(r_)
cat_r_ = torch.cat(
[r_[:, self.column_idx[col], :] for col, _ in self.cat_embed_input]
)
x_ = self.mlp(cat_r_)
return x, x_
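# Shape walk-through of the rewritten CatSingleMlp.forward above (an editorial
# note, assuming an attention-encoder output r_ of shape [batch, n_tokens, input_dim]):
#   x      = cat of per-column targets           -> [batch * n_cat_cols]
#   cat_r_ = cat of per-column token embeddings  -> [batch * n_cat_cols, input_dim]
#   x_     = self.mlp(cat_r_)                    -> [batch * n_cat_cols, num_class]
# i.e. every (row, categorical column) pair becomes one sample for the single
# cross-entropy term in DenoisingLoss.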
class CatFeaturesMlp(nn.Module):
def __init__(self, model, activation):
def __init__(self, input_dim, cat_embed_input, column_idx, activation):
super(CatFeaturesMlp, self).__init__()
self.column_idx = model.column_idx
self.cat_embed_input = model.cat_embed_input
self.input_dim = input_dim
self.column_idx = column_idx
self.cat_embed_input = cat_embed_input
self.activation = activation
self.mlp = nn.ModuleDict(
......@@ -63,7 +54,7 @@ class CatFeaturesMlp(nn.Module):
"mlp_"
+ col: MLP(
d_hidden=[
model.input_dim,
input_dim,
val * 4,
val,
],
......@@ -73,43 +64,33 @@ class CatFeaturesMlp(nn.Module):
batchnorm_last=False,
linear_first=False,
)
for col, val, _ in model.cat_embed_input
for col, val in self.cat_embed_input
}
)
def forward(self, X: Tensor, r_: Tensor) -> List[Tuple[Tensor, Tensor]]:
# '-1' because the Label Encoder is designed to leave 0 for unseen
# categories (or padding) during the supervised training. Here we
# are predicting directly the categories, and a categorical target
# has to start from 0
x = [
X[:, self.column_idx[col]].long() - 1 for col, _, _ in self.cat_embed_input
]
x = [X[:, self.column_idx[col]].long() for col, _ in self.cat_embed_input]
x_ = [
self.mlp["mlp_" + col](r_[:, self.column_idx[col], :])
for col, _, _ in self.cat_embed_input
for col, _ in self.cat_embed_input
]
return list(zip(x, x_))
class ContSingleMlp(nn.Module):
def __init__(self, model, activation):
def __init__(self, input_dim, continuous_cols, column_idx, activation):
super(ContSingleMlp, self).__init__()
self.column_idx = model.column_idx
self.continuous_cols = model.continuous_cols
self.input_dim = input_dim
self.column_idx = column_idx
self.continuous_cols = continuous_cols
self.activation = activation
self.mlp = MLP(
d_hidden=[
model.input_dim,
model.input_dim * 2,
1,
],
d_hidden=[input_dim, input_dim * 2, 1],
activation=activation,
dropout=0.0,
batchnorm=False,
......@@ -119,23 +100,26 @@ class ContSingleMlp(nn.Module):
def forward(self, X: Tensor, r_: Tensor) -> Tuple[Tensor, Tensor]:
x = torch.stack(
[X[:, self.column_idx[col]] for col in self.continuous_cols],
dim=1,
x = torch.cat(
[X[:, self.column_idx[col]].float() for col in self.continuous_cols]
).unsqueeze(1)
cont_r_ = torch.cat(
[r_[:, self.column_idx[col], :] for col in self.continuous_cols]
)
x_ = self.mlp(r_)
x_ = self.mlp(cont_r_)
return x, x_
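# Same pattern for the continuous columns (editorial shape note):
#   x       -> [batch * n_cont_cols, 1]          targets, one scalar per (row, column) pair
#   cont_r_ -> [batch * n_cont_cols, input_dim]  per-column token embeddings
#   x_      -> [batch * n_cont_cols, 1]          reconstructions fed to the MSE term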
class ContFeaturesMlp(nn.Module):
def __init__(self, model, activation):
def __init__(self, input_dim, continuous_cols, column_idx, activation):
super(ContFeaturesMlp, self).__init__()
self.column_idx = model.column_idx
self.continuous_cols = model.continuous_cols
self.input_dim = input_dim
self.column_idx = column_idx
self.continuous_cols = continuous_cols
self.activation = activation
self.mlp = nn.ModuleDict(
......@@ -143,8 +127,8 @@ class ContFeaturesMlp(nn.Module):
"mlp_"
+ col: MLP(
d_hidden=[
model.input_dim,
model.input_dim * 2,
input_dim,
input_dim * 2,
1,
],
activation=activation,
......@@ -153,7 +137,7 @@ class ContFeaturesMlp(nn.Module):
batchnorm_last=False,
linear_first=False,
)
for col in model.continuous_cols
for col in self.continuous_cols
}
)
......
......@@ -18,6 +18,7 @@ class SelfSupervisedModel(nn.Module):
def __init__(
self,
model: nn.Module,
encoding_dict: Dict[str, Dict[str, int]],
loss_type: Literal["contrastive", "denoising", "both"],
projection_head1_dims: Optional[List],
projection_head2_dims: Optional[List],
......@@ -29,7 +30,6 @@ class SelfSupervisedModel(nn.Module):
super(SelfSupervisedModel, self).__init__()
self.model = model
self.loss_type = loss_type
self.projection_head1_dims = projection_head1_dims
self.projection_head2_dims = projection_head2_dims
......@@ -38,12 +38,18 @@ class SelfSupervisedModel(nn.Module):
self.cont_mlp_type = cont_mlp_type
self.denoise_mlps_activation = denoise_mlps_activation
self.cat_embed_input, self.column_idx = self._adjust_if_with_cls_token(
encoding_dict
)
self.projection_head1, self.projection_head2 = self._set_projection_heads()
(
self.denoise_cat_mlp,
self.denoise_cont_mlp,
) = self._set_cat_and_cont_denoise_mlps()
self._t = self._tensor_to_subtract(encoding_dict)
def forward(
self, X: Tensor
) -> Tuple[
......@@ -52,11 +58,19 @@ class SelfSupervisedModel(nn.Module):
Optional[Tuple[Tensor, Tensor]],
]:
encoded = self.model.encoder(self.model._get_embeddings(X))
_X = self._prepare_x(X)
embed = self.model._get_embeddings(X)
_embed = embed[:, 1:] if self.model.with_cls_token else embed
encoded = self.model.encoder(_embed)
cut_mixed = cut_mix(X)
cut_mixed_embed = self.model._get_embeddings(cut_mixed)
cut_mixed_embed_mixed_up = mix_up(cut_mixed_embed)
_cut_mixed_embed = (
cut_mixed_embed[:, 1:] if self.model.with_cls_token else cut_mixed_embed
)
cut_mixed_embed_mixed_up = mix_up(_cut_mixed_embed)
encoded_ = self.model.encoder(cut_mixed_embed_mixed_up)
if self.loss_type in ["contrastive", "both"]:
......@@ -66,11 +80,11 @@ class SelfSupervisedModel(nn.Module):
if self.loss_type in ["denoising", "both"]:
if self.model.cat_embed_input is not None:
cat_x_and_x_ = self.denoise_cat_mlp(X, encoded_)
cat_x_and_x_ = self.denoise_cat_mlp(_X, encoded_)
else:
cat_x_and_x_ = None
if self.model.continuous_cols is not None:
cont_x_and_x_ = self.denoise_cont_mlp(X, encoded_)
cont_x_and_x_ = self.denoise_cont_mlp(_X, encoded_)
else:
cont_x_and_x_ = None
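# Data flow of the updated forward() (editorial summary of the lines above; the
# contrastive branch is in the collapsed part of this diff):
#   X -> embeddings -> (drop the [CLS] column if present) -> encoder           -> encoded
#   X -> cut_mix -> embeddings -> (drop [CLS]) -> mix_up -> encoder            -> encoded_
#   denoising: the cat/cont MLPs reconstruct the offset-corrected _X from encoded_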
......@@ -107,13 +121,85 @@ class SelfSupervisedModel(nn.Module):
def _set_cat_and_cont_denoise_mlps(self) -> Tuple[nn.Module, nn.Module]:
if self.cat_mlp_type == "single":
denoise_cat_mlp = CatSingleMlp(self.model, self.denoise_mlps_activation)
denoise_cat_mlp = CatSingleMlp(
self.model.input_dim,
self.cat_embed_input,
self.column_idx,
self.denoise_mlps_activation,
)
elif self.cat_mlp_type == "multiple":
denoise_cat_mlp = CatFeaturesMlp(self.model, self.denoise_mlps_activation)
denoise_cat_mlp = CatFeaturesMlp(
self.model.input_dim,
self.cat_embed_input,
self.column_idx,
self.denoise_mlps_activation,
)
if self.cont_mlp_type == "single":
denoise_cont_mlp = ContSingleMlp(self.model, self.denoise_mlps_activation)
denoise_cont_mlp = ContSingleMlp(
self.model.input_dim,
self.model.continuous_cols,
self.column_idx,
self.denoise_mlps_activation,
)
elif self.cont_mlp_type == "multiple":
denoise_cont_mlp = ContFeaturesMlp(self.model, self.denoise_mlps_activation)
denoise_cont_mlp = ContFeaturesMlp(
self.model.input_dim,
self.model.continuous_cols,
self.column_idx,
self.denoise_mlps_activation,
)
return denoise_cat_mlp, denoise_cont_mlp
def _prepare_x(self, X_tab: Tensor) -> Tensor:
_X_tab = X_tab[:, 1:] if self.model.with_cls_token else X_tab
return _X_tab - self._t.repeat(X_tab.size(0), 1)
def _adjust_if_with_cls_token(self, encoding_dict):
if self.model.with_cls_token:
adj_column_idx = {
k: self.model.column_idx[k] - 1
for k in self.model.column_idx
if k != "cls_token"
}
adj_cat_embed_input = self.model.cat_embed_input[1:]
else:
adj_column_idx = dict(self.model.column_idx)
adj_cat_embed_input = self.model.cat_embed_input
return adj_cat_embed_input, adj_column_idx
def _set_idx_to_substract(self, encoding_dict) -> Optional[Dict[str, int]]:
    if self.model.with_cls_token:
        adj_encoding_dict = {
            k: v for k, v in encoding_dict.items() if k != "cls_token"
        }
    else:
        adj_encoding_dict = encoding_dict
    if self.cat_mlp_type == "multiple":
        idx_to_substract: Optional[Dict[str, int]] = {
            k: min(sorted(list(v.values()))) for k, v in adj_encoding_dict.items()
        }
    if self.cat_mlp_type == "single":
        idx_to_substract = None
    return idx_to_substract
def _tensor_to_subtract(self, encoding_dict) -> Tensor:
self.idx_to_substract = self._set_idx_to_substract(encoding_dict)
_t = torch.zeros(len(self.column_idx))
if self.idx_to_substract is not None:
for colname, idx in self.idx_to_substract.items():
_t[self.column_idx[colname]] = idx
else:
for colname, _ in self.cat_embed_input:
# 0 is reserved for padding, 1 for the '[CLS]' token, if present
_t[self.column_idx[colname]] = 2 if self.model.with_cls_token else 1
return _t
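# Toy illustration (hypothetical encodings) of the offset tensor built above.
# Suppose encoding_dict = {"a": {"x": 1, "y": 2}, "b": {"p": 3, "q": 4}}, one
# continuous column "c", and column_idx = {"a": 0, "b": 1, "c": 2}:
#   cat_mlp_type == "multiple": subtract each column's minimum encoding so its
#     targets start at 0                            -> _t = [1., 3., 0.]
#   cat_mlp_type == "single" and no [CLS] token: subtract 1 from every categorical
#     column, since 0 is reserved for padding       -> _t = [1., 1., 0.]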
......@@ -18,6 +18,7 @@ class SelfSupervisedTrainer(BaseSelfSupervisedTrainer):
def __init__(
self,
model,
preprocessor,
optimizer: Optional[Optimizer] = None,
lr_scheduler: Optional[LRScheduler] = None,
callbacks: Optional[List[Callback]] = None,
......@@ -34,6 +35,7 @@ class SelfSupervisedTrainer(BaseSelfSupervisedTrainer):
):
super().__init__(
model=model,
preprocessor=preprocessor,
loss_type=loss_type,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
......