Commit 9e9c4d28 authored by Alexander Shirkov

[bugfix] unable to handle categorical column names with dots

Parent: efa8793e
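For context (my note, not part of the commit): PyTorch's nn.ModuleDict registers every entry through nn.Module.add_module, which rejects names containing "." because dots are reserved for sub-module paths, so a column called "col.1" crashed the old "emb_layer_" + col keys. Below is a minimal, self-contained sketch of the failure and of the sanitized-name mapping the patch introduces; the column names and dimensions are illustrative, only the embed_layers_names mapping mirrors the diff, and the exact error text may differ across PyTorch versions.

import torch.nn as nn

# A dotted key is refused because ModuleDict keys become module names.
try:
    nn.ModuleDict({"emb_layer_col.1": nn.Embedding(6, 8, padding_idx=0)})
except KeyError as err:
    print(err)  # PyTorch refuses the dotted key

# The patch keeps a per-column mapping to a dot-free key and uses it wherever
# a ModuleDict key is built; the original column name is still used elsewhere.
embed_input = [("col.1", 5, 8), ("col2", 5, 8)]
embed_layers_names = {e[0]: e[0].replace(".", "_") for e in embed_input}
embed_layers = nn.ModuleDict(
    {
        "emb_layer_" + embed_layers_names[col]: nn.Embedding(val + 1, dim, padding_idx=0)
        for col, val, dim in embed_input
    }
)
print(list(embed_layers.keys()))  # ['emb_layer_col_1', 'emb_layer_col2']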
@@ -128,10 +128,19 @@ class DiffSizeCatEmbeddings(nn.Module):
         self.embed_input = embed_input
         self.use_bias = use_bias
 
+        self.embed_layers_names = None
+        if self.embed_input is not None:
+            self.embed_layers_names = {
+                e[0]: e[0].replace(".", "_") for e in self.embed_input
+            }
+
         # Categorical: val + 1 because 0 is reserved for padding/unseen categories.
         self.embed_layers = nn.ModuleDict(
             {
-                "emb_layer_" + col: nn.Embedding(val + 1, dim, padding_idx=0)
+                "emb_layer_"
+                + self.embed_layers_names[col]: nn.Embedding(
+                    val + 1, dim, padding_idx=0
+                )
                 for col, val, dim in self.embed_input
             }
         )
@@ -152,7 +161,9 @@ class DiffSizeCatEmbeddings(nn.Module):
     def forward(self, X: Tensor) -> Tensor:
         embed = [
-            self.embed_layers["emb_layer_" + col](X[:, self.column_idx[col]].long())
+            self.embed_layers["emb_layer_" + self.embed_layers_names[col]](
+                X[:, self.column_idx[col]].long()
+            )
             + (
                 self.biases["bias_" + col].unsqueeze(0)
                 if self.use_bias
@@ -186,6 +197,12 @@ class SameSizeCatEmbeddings(nn.Module):
         self.shared_embed = shared_embed
         self.with_cls_token = "cls_token" in column_idx
 
+        self.embed_layers_names = None
+        if self.embed_input is not None:
+            self.embed_layers_names = {
+                e[0]: e[0].replace(".", "_") for e in self.embed_input
+            }
+
         categorical_cols = [ei[0] for ei in embed_input]
         self.cat_idx = [self.column_idx[col] for col in categorical_cols]
@@ -211,7 +228,7 @@ class SameSizeCatEmbeddings(nn.Module):
             self.embed: Union[nn.ModuleDict, nn.Embedding] = nn.ModuleDict(
                 {
                     "emb_layer_"
-                    + col: SharedEmbeddings(
+                    + self.embed_layers_names[col]: SharedEmbeddings(
                         val if col == "cls_token" else val + 1,
                         embed_dim,
                         embed_dropout,
@@ -233,9 +250,11 @@ class SameSizeCatEmbeddings(nn.Module):
     def forward(self, X: Tensor) -> Tensor:
         if self.shared_embed:
             cat_embed = [
-                self.embed["emb_layer_" + col](  # type: ignore[index]
+                self.embed["emb_layer_" + self.embed_layers_names[col]](  # type: ignore[index]
                     X[:, self.column_idx[col]].long()
-                ).unsqueeze(1)
+                ).unsqueeze(
+                    1
+                )
                 for col, _ in self.embed_input
             ]
             x = torch.cat(cat_embed, 1)
......
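A small illustration of how the two dictionaries interact after the change (my sketch, not library code): the forward pass still indexes the input tensor by the original, possibly dotted, column name through column_idx, and only translates the name when it needs a ModuleDict key. The names mirror the diff; the values and shapes are made up for the example.

import torch
import torch.nn as nn

column_idx = {"col.1": 0, "col2": 1}
embed_layers_names = {"col.1": "col_1", "col2": "col2"}
embed_layers = nn.ModuleDict(
    {
        "emb_layer_" + embed_layers_names[c]: nn.Embedding(6, 8, padding_idx=0)
        for c in column_idx
    }
)

X = torch.randint(0, 6, (4, 2))  # batch of 4 rows, 2 categorical columns
embed = [
    # dot-free key for the module lookup, original name for the column index
    embed_layers["emb_layer_" + embed_layers_names[col]](X[:, column_idx[col]].long())
    for col in column_idx
]
print(torch.cat([e.unsqueeze(1) for e in embed], 1).shape)  # torch.Size([4, 2, 8])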
@@ -405,6 +405,43 @@ def test_get_embeddings_deprecation_warning():
     )
 
 
+###############################################################################
+# test test_handle_columns_with_dots
+###############################################################################
+def test_handle_columns_with_dots():
+    data = df.copy()
+    data = data.rename(columns={"col1": "col.1", "a": "a.1"})
+
+    embed_cols = [("col.1", 5), ("col2", 5)]
+    continuous_cols = ["col3", "col4"]
+
+    tab_preprocessor = TabPreprocessor(
+        cat_embed_cols=embed_cols, continuous_cols=continuous_cols
+    )
+    X_tab = tab_preprocessor.fit_transform(data)
+    target = data.target.values
+
+    tabmlp = TabMlp(
+        mlp_hidden_dims=[32, 16],
+        mlp_dropout=[0.5, 0.5],
+        column_idx={k: v for v, k in enumerate(data.columns)},
+        cat_embed_input=tab_preprocessor.cat_embed_input,
+        continuous_cols=tab_preprocessor.continuous_cols,
+    )
+
+    model = WideDeep(deeptabular=tabmlp)
+
+    trainer = Trainer(model, objective="binary", verbose=0)
+    trainer.fit(
+        X_tab=X_tab,
+        target=target,
+        batch_size=16,
+    )
+    preds = trainer.predict(X_tab=X_tab, batch_size=16)
+
+    assert preds.shape[0] == 32 and "train_loss" in trainer.history
+
+
 ###############################################################################
 # test Label Distribution Smoothing
 ###############################################################################
......