Commit 77f99678 authored by Javier Rodriguez Zaurin

Self-supervised training seems to be running in a small number of cases. A few more cases still need to be tested. Adjust the unit tests and re-organise the code.
Parent ad68b1b2
@@ -50,7 +50,10 @@ if __name__ == "__main__":
    target = df[target].values

    tab_preprocessor = TabPreprocessor(
        cat_embed_cols=cat_embed_cols,
        continuous_cols=continuous_cols,
        with_attention=True,
        with_cls_token=True,
    )
    X_tab = tab_preprocessor.fit_transform(df)

@@ -63,5 +66,10 @@ if __name__ == "__main__":
        mlp_dropout=0.2,
    )

    ss_trainer = SelfSupervisedTrainer(
        model=tab_mlp,
        preprocessor=tab_preprocessor,
        cat_mlp_type="single",
        cont_mlp_type="single",
    )
    ss_trainer.pretrain(X_tab, n_epochs=1, batch_size=256)
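The fitted preprocessor is now passed to the trainer because the self-supervised model needs `preprocessor.label_encoder.encoding_dict` to map the encoded categories back to zero-based denoising targets. A rough, illustrative sketch of that structure (column names and codes are made up; only the overall shape, the reserved 0 for padding/unseen categories and the extra `cls_token` entry come from this commit):

# Illustrative only: TabPreprocessor.label_encoder.encoding_dict maps each
# categorical column to a {category: integer code} dict. Index 0 is reserved
# for padding/unseen categories and, with with_cls_token=True, an extra
# "cls_token" column holds the '[CLS]' token.
encoding_dict = {
    "cls_token": {"[CLS]": 1},
    "workclass": {"Private": 2, "State-gov": 3, "Self-emp": 4},
    "education": {"HS-grad": 5, "Bachelors": 6, "Masters": 7},
}

With the "multiple" denoising heads, the minimum code of each column (2 and 5 in this made-up example) is what the new `_set_idx_to_substract` later subtracts per column.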
@@ -857,22 +857,47 @@ class DenoisingLoss(nn.Module):
    def forward(
        self,
        x_cat_and_cat_: Optional[
            Union[List[Tuple[Tensor, Tensor]], Tuple[Tensor, Tensor]]
        ],
        x_cont_and_cont_: Optional[
            Union[List[Tuple[Tensor, Tensor]], Tuple[Tensor, Tensor]]
        ],
    ) -> Tensor:

        loss_cat = (
            self._compute_cat_loss(x_cat_and_cat_)
            if x_cat_and_cat_ is not None
            else torch.tensor(0.0)
        )
        loss_cont = (
            self._compute_cont_loss(x_cont_and_cont_)
            if x_cont_and_cont_ is not None
            else torch.tensor(0.0)
        )

        return self.lambda_cat * loss_cat + self.lambda_cont * loss_cont

    def _compute_cat_loss(self, x_cat_and_cat_):

        loss_cat = torch.tensor(0.0)
        if isinstance(x_cat_and_cat_, list):
            for x, x_ in x_cat_and_cat_:
                loss_cat += F.cross_entropy(x_, x, reduction=self.reduction)
        elif isinstance(x_cat_and_cat_, tuple):
            x, x_ = x_cat_and_cat_
            loss_cat += F.cross_entropy(x_, x, reduction=self.reduction)

        return loss_cat

    def _compute_cont_loss(self, x_cont_and_cont_):

        loss_cont = torch.tensor(0.0)
        if isinstance(x_cont_and_cont_, list):
            for x, x_ in x_cont_and_cont_:
                loss_cont += F.mse_loss(x_, x, reduction=self.reduction)
        elif isinstance(x_cont_and_cont_, tuple):
            x, x_ = x_cont_and_cont_
            loss_cont += F.mse_loss(x_, x, reduction=self.reduction)

        return loss_cont
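The split into `_compute_cat_loss`/`_compute_cont_loss` is there because the loss can now receive either a single `(targets, predictions)` tuple (from the "single" denoise MLPs) or a list with one tuple per column (from the "multiple" ones). A minimal, self-contained sketch of the two input shapes and of what each branch computes (sizes are made up):

import torch
import torch.nn.functional as F

# "single" head: all categorical columns folded into one classification problem
cat_single = (torch.randint(0, 8, (32,)), torch.randn(32, 8))
x, x_ = cat_single
loss_single = F.cross_entropy(x_, x, reduction="mean")

# "multiple" heads: one (targets, logits) pair per categorical column
cat_multiple = [
    (torch.randint(0, 5, (32,)), torch.randn(32, 5)),
    (torch.randint(0, 3, (32,)), torch.randn(32, 3)),
]
loss_multiple = sum(F.cross_entropy(x_, x, reduction="mean") for x, x_ in cat_multiple)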
@@ -292,6 +292,14 @@ class TabPreprocessor(BasePreprocessor):
        return df.copy()[self.continuous_cols]

    def _check_inputs(self):

        if self.with_cls_token and not self.with_attention:
            warnings.warn(
                "If 'with_cls_token' is set to 'True', 'with_attention' will "
                "automatically be set to 'True'"
            )
            self.with_attention = True

        if (self.cat_embed_cols is None) and (self.continuous_cols is None):
            raise ValueError(
                "'cat_embed_cols' and 'continuous_cols' are 'None'. Please, define at least one of the two."
...
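A quick sketch of the intended behaviour of the new check, assuming `_check_inputs` runs when the preprocessor is constructed (column names are illustrative):

import warnings

from pytorch_widedeep.preprocessing import TabPreprocessor

with warnings.catch_warnings(record=True) as w:
    warnings.simplefilter("always")
    tp = TabPreprocessor(
        cat_embed_cols=["color"],
        continuous_cols=["price"],
        with_cls_token=True,   # requested
        with_attention=False,  # inconsistent with the line above
    )

assert tp.with_attention is True  # promoted, with a warning recorded in w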
@@ -12,6 +12,7 @@ from pytorch_widedeep.callbacks import (
    CallbackContainer,
    LRShedulerCallback,
)
from pytorch_widedeep.preprocessing.tab_preprocessor import TabPreprocessor
from pytorch_widedeep.self_supervised_training.self_supervised_model import (
    SelfSupervisedModel,
)

@@ -21,6 +22,7 @@ class BaseSelfSupervisedTrainer(ABC):
    def __init__(
        self,
        model,
        preprocessor: TabPreprocessor,
        optimizer: Optional[Optimizer],
        lr_scheduler: Optional[LRScheduler],
        callbacks: Optional[List[Callback]],

@@ -38,6 +40,7 @@ class BaseSelfSupervisedTrainer(ABC):
        self.ss_model = SelfSupervisedModel(
            model,
            preprocessor.label_encoder.encoding_dict,
            loss_type,
            projection_head1_dims,
            projection_head2_dims,
...
@@ -6,23 +6,18 @@ from pytorch_widedeep.models.tabular.mlp._layers import MLP

class CatSingleMlp(nn.Module):
    def __init__(self, input_dim, cat_embed_input, column_idx, activation):
        super(CatSingleMlp, self).__init__()

        self.input_dim = input_dim
        self.column_idx = column_idx
        self.cat_embed_input = cat_embed_input
        self.activation = activation

        self.num_class = sum([ei[1] for ei in cat_embed_input])

        self.mlp = MLP(
            d_hidden=[input_dim, self.num_class * 4, self.num_class],
            activation=activation,
            dropout=0.0,
            batchnorm=False,

@@ -32,30 +27,26 @@ class CatSingleMlp(nn.Module):
    def forward(self, X: Tensor, r_: Tensor) -> Tuple[Tensor, Tensor]:

        x = torch.cat(
            [X[:, self.column_idx[col]].long() for col, _ in self.cat_embed_input]
        )

        cat_r_ = torch.cat(
            [r_[:, self.column_idx[col], :] for col, _ in self.cat_embed_input]
        )

        x_ = self.mlp(cat_r_)

        return x, x_


class CatFeaturesMlp(nn.Module):
    def __init__(self, input_dim, cat_embed_input, column_idx, activation):
        super(CatFeaturesMlp, self).__init__()

        self.input_dim = input_dim
        self.column_idx = column_idx
        self.cat_embed_input = cat_embed_input
        self.activation = activation

        self.mlp = nn.ModuleDict(

@@ -63,7 +54,7 @@ class CatFeaturesMlp(nn.Module):
                "mlp_"
                + col: MLP(
                    d_hidden=[
                        input_dim,
                        val * 4,
                        val,
                    ],

@@ -73,43 +64,33 @@ class CatFeaturesMlp(nn.Module):
                    batchnorm_last=False,
                    linear_first=False,
                )
                for col, val in self.cat_embed_input
            }
        )

    def forward(self, X: Tensor, r_: Tensor) -> List[Tuple[Tensor, Tensor]]:

        x = [X[:, self.column_idx[col]].long() for col, _ in self.cat_embed_input]

        x_ = [
            self.mlp["mlp_" + col](r_[:, self.column_idx[col], :])
            for col, _ in self.cat_embed_input
        ]

        return list(zip(x, x_))


class ContSingleMlp(nn.Module):
    def __init__(self, input_dim, continuous_cols, column_idx, activation):
        super(ContSingleMlp, self).__init__()

        self.input_dim = input_dim
        self.column_idx = column_idx
        self.continuous_cols = continuous_cols
        self.activation = activation

        self.mlp = MLP(
            d_hidden=[input_dim, input_dim * 2, 1],
            activation=activation,
            dropout=0.0,
            batchnorm=False,

@@ -119,23 +100,26 @@ class ContSingleMlp(nn.Module):
    def forward(self, X: Tensor, r_: Tensor) -> Tuple[Tensor, Tensor]:

        x = torch.cat(
            [X[:, self.column_idx[col]].float() for col in self.continuous_cols]
        ).unsqueeze(1)

        cont_r_ = torch.cat(
            [r_[:, self.column_idx[col], :] for col in self.continuous_cols]
        )

        x_ = self.mlp(cont_r_)

        return x, x_


class ContFeaturesMlp(nn.Module):
    def __init__(self, input_dim, continuous_cols, column_idx, activation):
        super(ContFeaturesMlp, self).__init__()

        self.input_dim = input_dim
        self.column_idx = column_idx
        self.continuous_cols = continuous_cols
        self.activation = activation

        self.mlp = nn.ModuleDict(

@@ -143,8 +127,8 @@ class ContFeaturesMlp(nn.Module):
                "mlp_"
                + col: MLP(
                    d_hidden=[
                        input_dim,
                        input_dim * 2,
                        1,
                    ],
                    activation=activation,

@@ -153,7 +137,7 @@ class ContFeaturesMlp(nn.Module):
                    batchnorm_last=False,
                    linear_first=False,
                )
                for col in self.continuous_cols
            }
        )
...
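The single-MLP heads no longer stack the columns side by side: targets and per-column encoder outputs are concatenated along the batch dimension, so one MLP scores every categorical column as extra rows of the same classification problem. A small shape sketch of that contract (all sizes illustrative):

import torch

batch_size, n_cat_cols, input_dim = 32, 2, 16

# r_ is the encoder output for the corrupted view: one input_dim vector per column
r_ = torch.randn(batch_size, n_cat_cols, input_dim)
X = torch.randint(1, 8, (batch_size, n_cat_cols)).float()

# What CatSingleMlp.forward now builds:
targets = torch.cat([X[:, j].long() for j in range(n_cat_cols)])
per_col_repr = torch.cat([r_[:, j, :] for j in range(n_cat_cols)])

assert targets.shape == (batch_size * n_cat_cols,)
assert per_col_repr.shape == (batch_size * n_cat_cols, input_dim)
# per_col_repr then goes through the MLP to produce (batch_size * n_cat_cols, num_class) logits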
@@ -18,6 +18,7 @@ class SelfSupervisedModel(nn.Module):
    def __init__(
        self,
        model: nn.Module,
        encoding_dict: Dict[str, Dict[str, int]],
        loss_type: Literal["contrastive", "denoising", "both"],
        projection_head1_dims: Optional[List],
        projection_head2_dims: Optional[List],

@@ -29,7 +30,6 @@ class SelfSupervisedModel(nn.Module):
        super(SelfSupervisedModel, self).__init__()

        self.model = model
        self.loss_type = loss_type
        self.projection_head1_dims = projection_head1_dims
        self.projection_head2_dims = projection_head2_dims

@@ -38,12 +38,18 @@ class SelfSupervisedModel(nn.Module):
        self.cont_mlp_type = cont_mlp_type
        self.denoise_mlps_activation = denoise_mlps_activation

        self.cat_embed_input, self.column_idx = self._adjust_if_with_cls_token(
            encoding_dict
        )

        self.projection_head1, self.projection_head2 = self._set_projection_heads()
        (
            self.denoise_cat_mlp,
            self.denoise_cont_mlp,
        ) = self._set_cat_and_cont_denoise_mlps()

        self._t = self._tensor_to_subtract(encoding_dict)

    def forward(
        self, X: Tensor
    ) -> Tuple[

@@ -52,11 +58,19 @@ class SelfSupervisedModel(nn.Module):
        Optional[Tuple[Tensor, Tensor]],
    ]:

        _X = self._prepare_x(X)

        embed = self.model._get_embeddings(X)
        _embed = embed[:, 1:] if self.model.with_cls_token else embed
        encoded = self.model.encoder(_embed)

        cut_mixed = cut_mix(X)
        cut_mixed_embed = self.model._get_embeddings(cut_mixed)
        _cut_mixed_embed = (
            cut_mixed_embed[:, 1:] if self.model.with_cls_token else cut_mixed_embed
        )
        cut_mixed_embed_mixed_up = mix_up(_cut_mixed_embed)
        encoded_ = self.model.encoder(cut_mixed_embed_mixed_up)

        if self.loss_type in ["contrastive", "both"]:

@@ -66,11 +80,11 @@ class SelfSupervisedModel(nn.Module):
        if self.loss_type in ["denoising", "both"]:
            if self.model.cat_embed_input is not None:
                cat_x_and_x_ = self.denoise_cat_mlp(_X, encoded_)
            else:
                cat_x_and_x_ = None
            if self.model.continuous_cols is not None:
                cont_x_and_x_ = self.denoise_cont_mlp(_X, encoded_)
            else:
                cont_x_and_x_ = None

@@ -107,13 +121,85 @@ class SelfSupervisedModel(nn.Module):
    def _set_cat_and_cont_denoise_mlps(self) -> Tuple[nn.Module, nn.Module]:

        if self.cat_mlp_type == "single":
            denoise_cat_mlp = CatSingleMlp(
                self.model.input_dim,
                self.cat_embed_input,
                self.column_idx,
                self.denoise_mlps_activation,
            )
        elif self.cat_mlp_type == "multiple":
            denoise_cat_mlp = CatFeaturesMlp(
                self.model.input_dim,
                self.cat_embed_input,
                self.column_idx,
                self.denoise_mlps_activation,
            )

        if self.cont_mlp_type == "single":
            denoise_cont_mlp = ContSingleMlp(
                self.model.input_dim,
                self.model.continuous_cols,
                self.column_idx,
                self.denoise_mlps_activation,
            )
        elif self.cont_mlp_type == "multiple":
            denoise_cont_mlp = ContFeaturesMlp(
                self.model.input_dim,
                self.model.continuous_cols,
                self.column_idx,
                self.denoise_mlps_activation,
            )

        return denoise_cat_mlp, denoise_cont_mlp
    def _prepare_x(self, X_tab: Tensor) -> Tensor:

        _X_tab = X_tab[:, 1:] if self.model.with_cls_token else X_tab

        return _X_tab - self._t.repeat(X_tab.size(0), 1)

    def _adjust_if_with_cls_token(self, encoding_dict):

        if self.model.with_cls_token:
            adj_column_idx = {
                k: self.model.column_idx[k] - 1
                for k in self.model.column_idx
                if k != "cls_token"
            }
            adj_cat_embed_input = self.model.cat_embed_input[1:]
        else:
            adj_column_idx = dict(self.model.column_idx)
            adj_cat_embed_input = self.model.cat_embed_input

        return adj_cat_embed_input, adj_column_idx

    def _set_idx_to_substract(self, encoding_dict) -> Optional[Dict[str, int]]:

        if self.model.with_cls_token:
            adj_encoding_dict = {
                k: v for k, v in encoding_dict.items() if k != "cls_token"
            }
        else:
            adj_encoding_dict = encoding_dict

        if self.cat_mlp_type == "multiple":
            idx_to_substract: Optional[Dict[str, int]] = {
                k: min(sorted(list(v.values()))) for k, v in adj_encoding_dict.items()
            }

        if self.cat_mlp_type == "single":
            idx_to_substract = None

        return idx_to_substract

    def _tensor_to_subtract(self, encoding_dict) -> Tensor:

        self.idx_to_substract = self._set_idx_to_substract(encoding_dict)

        _t = torch.zeros(len(self.column_idx))

        if self.idx_to_substract is not None:
            for colname, idx in self.idx_to_substract.items():
                _t[self.column_idx[colname]] = idx
        else:
            for colname, _ in self.cat_embed_input:
                # 0 is reserved for padding, 1 for the '[CLS]' token, if present
                _t[self.column_idx[colname]] = 2 if self.model.with_cls_token else 1

        return _t
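The `_tensor_to_subtract`/`_prepare_x` pair exists because the label encoder reserves the lowest indices (0 for padding/unseen categories and, when present, the next slot for the '[CLS]' token), while the denoising targets must start at 0. A minimal sketch of the offset for the cat_mlp_type="single" case without a '[CLS]' token (column layout is hypothetical):

import torch

# Hypothetical layout: two categorical columns and one continuous column
column_idx = {"color": 0, "size": 1, "price": 2}
cat_cols = ["color", "size"]

_t = torch.zeros(len(column_idx))
for col in cat_cols:
    _t[column_idx[col]] = 1  # would be 2 if a '[CLS]' token occupied index 1

X_tab = torch.tensor([[3.0, 2.0, 0.5], [1.0, 1.0, -0.2]])
_X = X_tab - _t.repeat(X_tab.size(0), 1)  # categorical codes shifted down, continuous untouched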
@@ -18,6 +18,7 @@ class SelfSupervisedTrainer(BaseSelfSupervisedTrainer):
    def __init__(
        self,
        model,
        preprocessor,
        optimizer: Optional[Optimizer] = None,
        lr_scheduler: Optional[LRScheduler] = None,
        callbacks: Optional[List[Callback]] = None,

@@ -34,6 +35,7 @@ class SelfSupervisedTrainer(BaseSelfSupervisedTrainer):
    ):
        super().__init__(
            model=model,
            preprocessor=preprocessor,
            loss_type=loss_type,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
...