From 9e9c4d2817e098f1ee4c7239bda43910e82fbe4c Mon Sep 17 00:00:00 2001
From: Alexander Shirkov
Date: Wed, 6 Apr 2022 12:16:58 -0700
Subject: [PATCH] [bugfix] unable to handle categorical column names with dots

---
 .../models/tabular/embeddings_layers.py | 29 ++++++++++++---
 .../test_miscellaneous.py               | 37 +++++++++++++++++++
 2 files changed, 61 insertions(+), 5 deletions(-)

diff --git a/pytorch_widedeep/models/tabular/embeddings_layers.py b/pytorch_widedeep/models/tabular/embeddings_layers.py
index e167936..21294a2 100644
--- a/pytorch_widedeep/models/tabular/embeddings_layers.py
+++ b/pytorch_widedeep/models/tabular/embeddings_layers.py
@@ -128,10 +128,19 @@ class DiffSizeCatEmbeddings(nn.Module):
         self.embed_input = embed_input
         self.use_bias = use_bias
 
+        self.embed_layers_names = None
+        if self.embed_input is not None:
+            self.embed_layers_names = {
+                e[0]: e[0].replace(".", "_") for e in self.embed_input
+            }
+
         # Categorical: val + 1 because 0 is reserved for padding/unseen categories.
         self.embed_layers = nn.ModuleDict(
             {
-                "emb_layer_" + col: nn.Embedding(val + 1, dim, padding_idx=0)
+                "emb_layer_"
+                + self.embed_layers_names[col]: nn.Embedding(
+                    val + 1, dim, padding_idx=0
+                )
                 for col, val, dim in self.embed_input
             }
         )
@@ -152,7 +161,9 @@
 
     def forward(self, X: Tensor) -> Tensor:
         embed = [
-            self.embed_layers["emb_layer_" + col](X[:, self.column_idx[col]].long())
+            self.embed_layers["emb_layer_" + self.embed_layers_names[col]](
+                X[:, self.column_idx[col]].long()
+            )
             + (
                 self.biases["bias_" + col].unsqueeze(0)
                 if self.use_bias
@@ -186,6 +197,12 @@ class SameSizeCatEmbeddings(nn.Module):
         self.shared_embed = shared_embed
         self.with_cls_token = "cls_token" in column_idx
 
+        self.embed_layers_names = None
+        if self.embed_input is not None:
+            self.embed_layers_names = {
+                e[0]: e[0].replace(".", "_") for e in self.embed_input
+            }
+
         categorical_cols = [ei[0] for ei in embed_input]
         self.cat_idx = [self.column_idx[col] for col in categorical_cols]
 
@@ -211,7 +228,7 @@
         self.embed: Union[nn.ModuleDict, nn.Embedding] = nn.ModuleDict(
             {
                 "emb_layer_"
-                + col: SharedEmbeddings(
+                + self.embed_layers_names[col]: SharedEmbeddings(
                     val if col == "cls_token" else val + 1,
                     embed_dim,
                     embed_dropout,
@@ -233,9 +250,11 @@
     def forward(self, X: Tensor) -> Tensor:
         if self.shared_embed:
             cat_embed = [
-                self.embed["emb_layer_" + col](  # type: ignore[index]
+                self.embed["emb_layer_" + self.embed_layers_names[col]](  # type: ignore[index]
                     X[:, self.column_idx[col]].long()
-                ).unsqueeze(1)
+                ).unsqueeze(
+                    1
+                )
                 for col, _ in self.embed_input
             ]
             x = torch.cat(cat_embed, 1)
diff --git a/tests/test_model_functioning/test_miscellaneous.py b/tests/test_model_functioning/test_miscellaneous.py
index 8f27095..64ef44b 100644
--- a/tests/test_model_functioning/test_miscellaneous.py
+++ b/tests/test_model_functioning/test_miscellaneous.py
@@ -405,6 +405,43 @@ def test_get_embeddings_deprecation_warning():
     )
 
 
+###############################################################################
+# test test_handle_columns_with_dots
+###############################################################################
+
+
+def test_handle_columns_with_dots():
+    data = df.copy()
+    data = data.rename(columns={"col1": "col.1", "a": "a.1"})
+
+    embed_cols = [("col.1", 5), ("col2", 5)]
+    continuous_cols = ["col3", "col4"]
+
+    tab_preprocessor = TabPreprocessor(
+        cat_embed_cols=embed_cols, continuous_cols=continuous_cols
+    )
+    X_tab = tab_preprocessor.fit_transform(data)
+    target = data.target.values
+
+    tabmlp = TabMlp(
+        mlp_hidden_dims=[32, 16],
+        mlp_dropout=[0.5, 0.5],
+        column_idx={k: v for v, k in enumerate(data.columns)},
+        cat_embed_input=tab_preprocessor.cat_embed_input,
+        continuous_cols=tab_preprocessor.continuous_cols,
+    )
+
+    model = WideDeep(deeptabular=tabmlp)
+    trainer = Trainer(model, objective="binary", verbose=0)
+    trainer.fit(
+        X_tab=X_tab,
+        target=target,
+        batch_size=16,
+    )
+    preds = trainer.predict(X_tab=X_tab, batch_size=16)
+    assert preds.shape[0] == 32 and "train_loss" in trainer.history
+
+
 ###############################################################################
 # test Label Distribution Smoothing
 ###############################################################################
-- 
GitLab
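
Reviewer note (not part of the patch): the root cause is that nn.ModuleDict
registers every entry through nn.Module.add_module, which rejects keys
containing "." because dots delimit submodule paths in state_dict keys. The
minimal sketch below reproduces the failure and the column-name to
sanitized-key mapping the patch introduces; the toy embed_input values and
sizes are made up for illustration, and the variable names deliberately
mirror those in the patch.

import torch
import torch.nn as nn

# A key containing a dot makes nn.ModuleDict raise immediately on insertion.
try:
    nn.ModuleDict({"emb_layer_col.1": nn.Embedding(6, 4, padding_idx=0)})
except KeyError as err:
    print(err)  # -> 'module name can't contain ".", got: emb_layer_col.1'

# The patch's workaround: map each column name to a dot-free key and use that
# mapping both when building the ModuleDict and when indexing it in forward().
embed_input = [("col.1", 5, 4), ("col2", 5, 4)]  # (col, n_categories, dim)
embed_layers_names = {col: col.replace(".", "_") for col, _, _ in embed_input}
embed_layers = nn.ModuleDict(
    {
        "emb_layer_" + embed_layers_names[col]: nn.Embedding(
            val + 1, dim, padding_idx=0  # val + 1: index 0 is padding/unseen
        )
        for col, val, dim in embed_input
    }
)

# Lookups go through the same mapping, so the original column name still works.
out = embed_layers["emb_layer_" + embed_layers_names["col.1"]](torch.tensor([1, 2]))
print(out.shape)  # torch.Size([2, 4])

One caveat of this sanitization scheme: two column names that differ only in
"." versus "_" (e.g. "col.1" and "col_1") would collide on the same key.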