Commit 9e9c4d28 authored by Alexander Shirkov

[bugfix] unable to handle categorical column names with dots

Parent efa8793e
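
Why the fix is needed (a minimal sketch, not part of the commit itself): torch.nn.Module forbids "." in submodule names, so an nn.ModuleDict keyed directly by a column name such as "col.1" raises a KeyError. The commit therefore keeps a per-column mapping of sanitized names ("." replaced by "_") and uses the sanitized form only for the layer keys, while forward lookups still use the original column names. The column name and embedding sizes below are illustrative only.

import torch.nn as nn

# Assumed reproduction of the reported bug: nn.ModuleDict delegates to
# Module.add_module, which rejects submodule names containing ".".
col = "col.1"
try:
    nn.ModuleDict({"emb_layer_" + col: nn.Embedding(6, 8, padding_idx=0)})
except KeyError as err:
    print(err)  # module name can't contain "."

# The workaround applied by the commit: sanitize the layer key, but keep the
# original column name for indexing the input tensor.
safe_col = col.replace(".", "_")
layers = nn.ModuleDict({"emb_layer_" + safe_col: nn.Embedding(6, 8, padding_idx=0)})
print(list(layers.keys()))  # ['emb_layer_col_1']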
@@ -128,10 +128,19 @@ class DiffSizeCatEmbeddings(nn.Module):
         self.embed_input = embed_input
         self.use_bias = use_bias

+        self.embed_layers_names = None
+        if self.embed_input is not None:
+            self.embed_layers_names = {
+                e[0]: e[0].replace(".", "_") for e in self.embed_input
+            }
+
         # Categorical: val + 1 because 0 is reserved for padding/unseen cateogories.
         self.embed_layers = nn.ModuleDict(
             {
-                "emb_layer_" + col: nn.Embedding(val + 1, dim, padding_idx=0)
+                "emb_layer_"
+                + self.embed_layers_names[col]: nn.Embedding(
+                    val + 1, dim, padding_idx=0
+                )
                 for col, val, dim in self.embed_input
             }
         )
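
For illustration (hypothetical values, not from the commit): given one dotted and one plain column, the new embed_layers_names mapping leaves ordinary names untouched and only rewrites the dotted ones.

# Assumed (column, n_categories, embed_dim) tuples, matching the shape of embed_input above.
embed_input = [("col.1", 5, 8), ("col2", 3, 8)]
embed_layers_names = {e[0]: e[0].replace(".", "_") for e in embed_input}
print(embed_layers_names)  # {'col.1': 'col_1', 'col2': 'col2'}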
@@ -152,7 +161,9 @@ class DiffSizeCatEmbeddings(nn.Module):
     def forward(self, X: Tensor) -> Tensor:
         embed = [
-            self.embed_layers["emb_layer_" + col](X[:, self.column_idx[col]].long())
+            self.embed_layers["emb_layer_" + self.embed_layers_names[col]](
+                X[:, self.column_idx[col]].long()
+            )
             + (
                 self.biases["bias_" + col].unsqueeze(0)
                 if self.use_bias
@@ -186,6 +197,12 @@ class SameSizeCatEmbeddings(nn.Module):
         self.shared_embed = shared_embed
         self.with_cls_token = "cls_token" in column_idx

+        self.embed_layers_names = None
+        if self.embed_input is not None:
+            self.embed_layers_names = {
+                e[0]: e[0].replace(".", "_") for e in self.embed_input
+            }
+
         categorical_cols = [ei[0] for ei in embed_input]
         self.cat_idx = [self.column_idx[col] for col in categorical_cols]
@@ -211,7 +228,7 @@ class SameSizeCatEmbeddings(nn.Module):
             self.embed: Union[nn.ModuleDict, nn.Embedding] = nn.ModuleDict(
                 {
                     "emb_layer_"
-                    + col: SharedEmbeddings(
+                    + self.embed_layers_names[col]: SharedEmbeddings(
                         val if col == "cls_token" else val + 1,
                         embed_dim,
                         embed_dropout,
@@ -233,9 +250,11 @@ class SameSizeCatEmbeddings(nn.Module):
     def forward(self, X: Tensor) -> Tensor:
         if self.shared_embed:
             cat_embed = [
-                self.embed["emb_layer_" + col](  # type: ignore[index]
+                self.embed["emb_layer_" + self.embed_layers_names[col]](  # type: ignore[index]
                     X[:, self.column_idx[col]].long()
-                ).unsqueeze(1)
+                ).unsqueeze(
+                    1
+                )
                 for col, _ in self.embed_input
             ]
             x = torch.cat(cat_embed, 1)
@@ -405,6 +405,43 @@ def test_get_embeddings_deprecation_warning():
     )


+###############################################################################
+# test test_handle_columns_with_dots
+###############################################################################
+def test_handle_columns_with_dots():
+    data = df.copy()
+    data = data.rename(columns={"col1": "col.1", "a": "a.1"})
+
+    embed_cols = [("col.1", 5), ("col2", 5)]
+    continuous_cols = ["col3", "col4"]
+
+    tab_preprocessor = TabPreprocessor(
+        cat_embed_cols=embed_cols, continuous_cols=continuous_cols
+    )
+    X_tab = tab_preprocessor.fit_transform(data)
+    target = data.target.values
+
+    tabmlp = TabMlp(
+        mlp_hidden_dims=[32, 16],
+        mlp_dropout=[0.5, 0.5],
+        column_idx={k: v for v, k in enumerate(data.columns)},
+        cat_embed_input=tab_preprocessor.cat_embed_input,
+        continuous_cols=tab_preprocessor.continuous_cols,
+    )
+
+    model = WideDeep(deeptabular=tabmlp)
+    trainer = Trainer(model, objective="binary", verbose=0)
+    trainer.fit(
+        X_tab=X_tab,
+        target=target,
+        batch_size=16,
+    )
+    preds = trainer.predict(X_tab=X_tab, batch_size=16)
+
+    assert preds.shape[0] == 32 and "train_loss" in trainer.history
+
+
 ###############################################################################
 # test Label Distribution Smoothing
 ###############################################################################
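
To exercise the new test locally, pytest's -k filter can select it by name, for example pytest -k test_handle_columns_with_dots run from the repository root (the exact test file path is not shown in this excerpt).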