Commit a4770eae authored by: J jrzaurin

re-organized the transformer module and added unit tests

Parent b1b8d8bb
......@@ -108,7 +108,7 @@ if __name__ == "__main__":
embed_input=prepare_deep.embeddings_input,
continuous_cols=continuous_cols,
input_dim=64,
dim_k=32,
dim_k=6,
n_blocks=3,
n_heads=4,
)
......
......@@ -17,7 +17,7 @@ def create_explain_matrix(model: WideDeep) -> csc_matrix:
Examples
--------
>>> from pytorch_widedeep.models import TabNet, WideDeep
>>> from pytorch_widedeep.models.tabnet.tab_net_utils import create_explain_matrix
>>> from pytorch_widedeep.models.tabnet._utils import create_explain_matrix
>>> embed_input = [("a", 4, 2), ("b", 4, 2), ("c", 4, 2)]
>>> cont_cols = ["d", "e"]
>>> column_idx = {k: v for v, k in enumerate(["a", "b", "c", "d", "e"])}
......
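The docstring example above is cut off by the diff. A plausible continuation, sketched from the TabNet/WideDeep API used elsewhere in this commit (the exact TabNet argument list is an assumption), would be:

>>> deeptabular = TabNet(column_idx=column_idx, embed_input=embed_input, continuous_cols=cont_cols)
>>> model = WideDeep(deeptabular=deeptabular)
>>> explain_matrix = create_explain_matrix(model)  # csc_matrix mapping embedding dims back to input columns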
......@@ -3,7 +3,7 @@ from torch import nn
from pytorch_widedeep.wdtypes import * # noqa: F403
from pytorch_widedeep.models.tab_mlp import CatEmbeddingsAndCont
from pytorch_widedeep.models.tabnet.layers import (
from pytorch_widedeep.models.tabnet._layers import (
TabNetEncoder,
initialize_non_glu,
)
......
......@@ -135,7 +135,7 @@ class CategoricalEmbeddings(nn.Module):
# Categorical: val + 1 because 0 is reserved for padding/unseen categories.
if self.shared_embed:
self.cat_embed: Union[nn.ModuleDict, nn.Embedding] = nn.ModuleDict(
self.embed: Union[nn.ModuleDict, nn.Embedding] = nn.ModuleDict(
{
"emb_layer_"
+ col: SharedEmbeddings(
......@@ -150,31 +150,29 @@ class CategoricalEmbeddings(nn.Module):
}
)
else:
self.cat_embed = nn.Embedding(self.n_tokens + 1, embed_dim, padding_idx=0)
self.embed = nn.Embedding(self.n_tokens + 1, embed_dim, padding_idx=0)
if full_embed_dropout:
self.embedding_dropout: DropoutLayers = FullEmbeddingDropout(
embed_dropout
)
self.dropout: DropoutLayers = FullEmbeddingDropout(embed_dropout)
else:
self.embedding_dropout = nn.Dropout(embed_dropout)
self.dropout = nn.Dropout(embed_dropout)
def forward(self, X: Tensor) -> Tensor:
if self.shared_embed:
cat_embed = [
self.cat_embed["emb_layer_" + col]( # type: ignore[index]
self.embed["emb_layer_" + col]( # type: ignore[index]
X[:, self.column_idx[col]].long()
).unsqueeze(1)
for col, _ in self.embed_input
]
x_cat = torch.cat(cat_embed, 1)
x = torch.cat(cat_embed, 1)
else:
x_cat = self.cat_embed(X[:, self.cat_idx].long())
x_cat = self.embedding_dropout(x_cat)
x = self.embed(X[:, self.cat_idx].long())
x = self.dropout(x)
if self.bias is not None:
x_cat = x_cat + self.bias.unsqueeze(0)
x = x + self.bias.unsqueeze(0)
return x_cat
return x
class CatAndContEmbeddings(nn.Module):
......
......@@ -2,7 +2,7 @@ import einops
from torch import nn
from pytorch_widedeep.wdtypes import * # noqa: F403
from pytorch_widedeep.models.transformers.attention_layers import (
from pytorch_widedeep.models.transformers._attention_layers import (
AddNorm,
NormAdd,
PositionwiseFF,
......
......@@ -2,8 +2,8 @@ from torch import nn
from pytorch_widedeep.wdtypes import * # noqa: F403
from pytorch_widedeep.models.tab_mlp import MLP
from pytorch_widedeep.models.transformers.encoders import FTTransformerEncoder
from pytorch_widedeep.models.transformers.embedding_layers import (
from pytorch_widedeep.models.transformers._encoders import FTTransformerEncoder
from pytorch_widedeep.models.transformers._embeddings_layers import (
CatAndContEmbeddings,
)
......
......@@ -2,8 +2,8 @@ from torch import nn
from pytorch_widedeep.wdtypes import * # noqa: F403
from pytorch_widedeep.models.tab_mlp import MLP
from pytorch_widedeep.models.transformers.encoders import SaintEncoder
from pytorch_widedeep.models.transformers.embedding_layers import (
from pytorch_widedeep.models.transformers._encoders import SaintEncoder
from pytorch_widedeep.models.transformers._embeddings_layers import (
CatAndContEmbeddings,
)
......
......@@ -2,8 +2,8 @@ from torch import nn
from pytorch_widedeep.wdtypes import * # noqa: F403
from pytorch_widedeep.models.tab_mlp import MLP
from pytorch_widedeep.models.transformers.encoders import FastFormerEncoder
from pytorch_widedeep.models.transformers.embedding_layers import (
from pytorch_widedeep.models.transformers._encoders import FastFormerEncoder
from pytorch_widedeep.models.transformers._embeddings_layers import (
CatAndContEmbeddings,
)
......@@ -65,10 +65,10 @@ class TabFastFormer(nn.Module):
Dropout that will be applied to the MultiHeadAttention module
ff_dropout: float, default = 0.1
Dropout that will be applied to the FeedForward network
share_qv_weights: bool, default = True
share_qv_weights: bool, default = False
Following the original publication, this is a boolean indicating if
the value and query transformation parameters will be shared
share_weights: bool, default = True
share_weights: bool, default = False
In addition to sharing the value and query transformation parameters,
the parameters across different Fastformer layers are also shared in
the paper.
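With the defaults flipped to False, the paper's behaviour now has to be requested explicitly. A minimal sketch (constructor arguments taken from the tests later in this diff; column_idx, embed_input and cont_cols are assumed to be built as in the other examples):

>>> model = TabFastFormer(
...     column_idx=column_idx,
...     embed_input=embed_input,
...     continuous_cols=cont_cols,
...     n_blocks=2,
...     n_heads=2,
...     share_qv_weights=True,  # opt back in to the paper's query/value sharing
...     share_weights=True,     # and to sharing parameters across Fastformer layers
... )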
......@@ -138,8 +138,8 @@ class TabFastFormer(nn.Module):
n_blocks: int = 6,
attn_dropout: float = 0.1,
ff_dropout: float = 0.2,
share_qv_weights: bool = True,
share_weights: bool = True,
share_qv_weights: bool = False,
share_weights: bool = False,
transformer_activation: str = "relu",
mlp_hidden_dims: Optional[List[int]] = None,
mlp_activation: str = "relu",
......
......@@ -4,8 +4,8 @@ from torch import nn
from pytorch_widedeep.wdtypes import * # noqa: F403
from pytorch_widedeep.models.tab_mlp import MLP
from pytorch_widedeep.models.transformers.encoders import PerceiverEncoder
from pytorch_widedeep.models.transformers.embedding_layers import (
from pytorch_widedeep.models.transformers._encoders import PerceiverEncoder
from pytorch_widedeep.models.transformers._embeddings_layers import (
CatAndContEmbeddings,
)
......@@ -82,7 +82,7 @@ class TabPerceiver(nn.Module):
n_perceiver_blocks: int, default = 4
Number of Perceiver blocks defined as [Cross Attention -> Latent
Transformer]
share_weights: Boolean, default = True
share_weights: Boolean, default = False
Boolean indicating if the weights will be shared between Perceiver
blocks
attn_dropout: float, default = 0.2
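As with TabFastFormer, weight sharing between Perceiver blocks is now opt-in. A minimal sketch using the same arguments that appear in the tests below (values are illustrative only):

>>> model = TabPerceiver(
...     column_idx=column_idx,
...     embed_input=embed_input,
...     continuous_cols=cont_cols,
...     n_perceiver_blocks=2,
...     n_latents=2,
...     latent_dim=16,
...     share_weights=True,  # share weights across the Perceiver blocks
... )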
......@@ -161,7 +161,7 @@ class TabPerceiver(nn.Module):
n_latent_heads: int = 4,
n_latent_blocks: int = 4,
n_perceiver_blocks: int = 4,
share_weights: bool = True,
share_weights: bool = False,
attn_dropout: float = 0.1,
ff_dropout: float = 0.1,
transformer_activation: str = "geglu",
......
......@@ -3,8 +3,8 @@ from torch import nn
from pytorch_widedeep.wdtypes import * # noqa: F403
from pytorch_widedeep.models.tab_mlp import MLP
from pytorch_widedeep.models.transformers.encoders import TransformerEncoder
from pytorch_widedeep.models.transformers.embedding_layers import (
from pytorch_widedeep.models.transformers._encoders import TransformerEncoder
from pytorch_widedeep.models.transformers._embeddings_layers import (
CatAndContEmbeddings,
)
......@@ -266,7 +266,10 @@ class TabTransformer(nn.Module):
def _compute_attn_output_dim(self) -> int:
if self.with_cls_token:
attn_output_dim = self.input_dim
if self.embed_continuous:
attn_output_dim = self.input_dim
else:
attn_output_dim = self.input_dim + self.n_cont
elif self.embed_continuous:
attn_output_dim = (self.n_cat + self.n_cont) * self.input_dim
else:
......
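The branches above decide the size of the flattened attention output that feeds the MLP head. A rough standalone summary of the new logic; the final branch is elided in the diff and is inferred from the tests further down, so it should be read as an assumption:

def expected_attn_output_dim(input_dim, n_cat, n_cont, with_cls_token, embed_continuous):
    # Mirrors TabTransformer._compute_attn_output_dim as shown in the hunk above
    if with_cls_token:
        return input_dim if embed_continuous else input_dim + n_cont
    if embed_continuous:
        return (n_cat + n_cont) * input_dim
    # elided branch, inferred from the tests: only categorical cols are attended over
    return n_cat * input_dim + n_cont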
......@@ -192,7 +192,7 @@ def alias_to_loss(loss_fn: str, **kwargs):
Examples
--------
>>> from pytorch_widedeep.training.trainer_utils import alias_to_loss
>>> from pytorch_widedeep.training._trainer_utils import alias_to_loss
>>> loss_fn = alias_to_loss(loss_fn="binary_logloss", weight=None)
"""
if loss_fn not in _ObjectiveToMethod.keys():
......
......@@ -25,14 +25,14 @@ from pytorch_widedeep.dataloaders import DataLoaderDefault
from pytorch_widedeep.initializers import Initializer, MultipleInitializer
from pytorch_widedeep.training._finetune import FineTune
from pytorch_widedeep.utils.general_utils import Alias
from pytorch_widedeep.models.tabnet._utils import create_explain_matrix
from pytorch_widedeep.training._wd_dataset import WideDeepDataset
from pytorch_widedeep.training.trainer_utils import (
from pytorch_widedeep.training._trainer_utils import (
alias_to_loss,
save_epoch_logs,
wd_train_val_split,
print_loss_and_metric,
)
from pytorch_widedeep.models.tabnet.tab_net_utils import create_explain_matrix
from pytorch_widedeep.training._multiple_optimizer import MultipleOptimizer
from pytorch_widedeep.training._multiple_transforms import MultipleTransforms
from pytorch_widedeep.training._loss_and_obj_aliases import _ObjectiveToMethod
......
......@@ -5,8 +5,8 @@ import torch
import pytest
from pytorch_widedeep.wdtypes import WideDeep
from pytorch_widedeep.models.tabnet._utils import create_explain_matrix
from pytorch_widedeep.models.tabnet.tab_net import TabNet # noqa: F403
from pytorch_widedeep.models.tabnet.tab_net_utils import create_explain_matrix
# I am going to over-test this model due to the number of components
......
......@@ -8,10 +8,12 @@ import pytest
from pytorch_widedeep.models import (
SAINT,
TabPerceiver,
FTTransformer,
TabFastFormer,
TabTransformer,
)
from pytorch_widedeep.models.transformers.layers import * # noqa: F403
from pytorch_widedeep.models.transformers._attention_layers import * # noqa: F403
from pytorch_widedeep.models.transformers._embeddings_layers import * # noqa: F403
# I am going to over-test these models due to the number of components
......@@ -44,10 +46,12 @@ model1 = TabTransformer(
def test_embeddings_have_padding():
res = []
res.append(
model1.cat_embed_and_cont.cat_embed.weight.size(0)
== model1.cat_embed_and_cont.n_tokens + 1
model1.cat_and_cont_embed.cat_embed.embed.weight.size(0)
== model1.cat_and_cont_embed.cat_embed.n_tokens + 1
)
res.append(
not torch.all(model1.cat_and_cont_embed.cat_embed.embed.weight[0].bool())
)
res.append(not torch.all(model1.cat_embed_and_cont.cat_embed.weight[0].bool()))
assert all(res)
......@@ -104,7 +108,7 @@ model2 = TabTransformer(
def test_shared_embeddings_have_padding():
res = []
for k, v in model2.cat_embed_and_cont.cat_embed.items():
for k, v in model2.cat_and_cont_embed.cat_embed.embed.items():
res.append(v.embed.weight.size(0) == n_embed + 1)
res.append(not torch.all(v.embed.weight[0].bool()))
assert all(res)
......@@ -128,7 +132,7 @@ def test_continuous_embeddings():
X = torch.rand(bsz, n_cont_cols)
cont_embed = ContinuousEmbeddings(
n_cont_cols=n_cont_cols, embed_dim=embed_dim, activation=None, bias=None
n_cont_cols=n_cont_cols, embed_dim=embed_dim, activation=None, use_bias=False
)
out = cont_embed(X)
res = (
......@@ -177,6 +181,19 @@ def test_full_embed_dropout():
# ###############################################################################
def _build_model(model_name, params):
if model_name == "tabtransformer":
return TabTransformer(n_blocks=2, n_heads=2, **params)
if model_name == "saint":
return SAINT(n_blocks=2, n_heads=2, **params)
if model_name == "fttransformer":
return FTTransformer(n_blocks=2, n_heads=2, dim_k=2, **params)
if model_name == "tabfastformer":
return TabFastFormer(n_blocks=2, n_heads=2, **params)
if model_name == "tabperceiver":
return TabPerceiver(n_perceiver_blocks=2, n_latents=2, latent_dim=16, **params)
@pytest.mark.parametrize(
"embed_continuous, with_cls_token, model_name",
[
......@@ -184,14 +201,6 @@ def test_full_embed_dropout():
(True, False, "tabtransformer"),
(False, True, "tabtransformer"),
(False, False, "tabtransformer"),
(True, True, "saint"),
(True, False, "saint"),
(False, True, "saint"),
(False, False, "saint"),
(True, True, "tabfastformer"),
(True, False, "tabfastformer"),
(False, True, "tabfastformer"),
(False, False, "tabfastformer"),
],
)
def test_embed_continuous_and_with_cls_token(
......@@ -207,48 +216,30 @@ def test_embed_continuous_and_with_cls_token(
n_colnames = copy(colnames)
cont_idx = n_cols
if model_name == "tabtransformer":
model = TabTransformer(
column_idx={k: v for v, k in enumerate(n_colnames)},
embed_input=with_cls_token_embed_input if with_cls_token else embed_input,
continuous_cols=n_colnames[cont_idx:],
embed_continuous=embed_continuous,
n_blocks=4,
)
elif model_name == "saint":
model = SAINT(
column_idx={k: v for v, k in enumerate(n_colnames)},
embed_input=with_cls_token_embed_input if with_cls_token else embed_input,
continuous_cols=n_colnames[cont_idx:],
embed_continuous=embed_continuous,
n_blocks=4,
)
elif model_name == "tabfastformer":
model = TabFastFormer(
column_idx={k: v for v, k in enumerate(n_colnames)},
embed_input=with_cls_token_embed_input if with_cls_token else embed_input,
continuous_cols=n_colnames[cont_idx:],
embed_continuous=embed_continuous,
n_blocks=4,
n_heads=4,
share_qv_weights=False,
share_weights=False,
)
params = {
"column_idx": {k: v for v, k in enumerate(n_colnames)},
"embed_input": with_cls_token_embed_input if with_cls_token else embed_input,
"continuous_cols": n_colnames[cont_idx:],
"embed_continuous": embed_continuous,
}
model = _build_model(model_name, params)
out = model(X)
res = [out.size(0) == 10]
if with_cls_token:
if embed_continuous:
res.append(model._set_mlp_hidden_dims()[0] == model.input_dim)
res.append(model._compute_attn_output_dim() == model.input_dim)
else:
res.append(
model._set_mlp_hidden_dims()[0] == model.input_dim + len(cont_cols)
model._compute_attn_output_dim() == model.input_dim + len(cont_cols)
)
elif embed_continuous:
mlp_first_h = X.shape[1] * model.input_dim
res.append(model._set_mlp_hidden_dims()[0] == mlp_first_h)
res.append(model._compute_attn_output_dim() == mlp_first_h)
else:
mlp_first_h = len(embed_cols) * model.input_dim + 2
res.append(model._set_mlp_hidden_dims()[0] == mlp_first_h)
res.append(model._compute_attn_output_dim() == mlp_first_h)
assert all(res)
......@@ -257,49 +248,38 @@ def test_embed_continuous_and_with_cls_token(
"activation, model_name",
[
("tanh", "tabtransformer"),
("relu", "tabtransformer"),
("leaky_relu", "tabtransformer"),
("gelu", "tabtransformer"),
("geglu", "tabtransformer"),
("reglu", "tabtransformer"),
("tanh", "saint"),
("relu", "saint"),
("leaky_relu", "saint"),
("gelu", "saint"),
("geglu", "saint"),
("reglu", "saint"),
("tanh", "fttransformer"),
("leaky_relu", "fttransformer"),
("geglu", "fttransformer"),
("reglu", "fttransformer"),
("tanh", "tabfastformer"),
("leaky_relu", "tabfastformer"),
("geglu", "tabfastformer"),
("reglu", "tabfastformer"),
("tanh", "tabperceiver"),
("relu", "tabperceiver"),
("leaky_relu", "tabperceiver"),
("gelu", "tabperceiver"),
("geglu", "tabperceiver"),
("reglu", "tabperceiver"),
],
)
def test_transformer_activations(activation, model_name):
if model_name == "tabtransformer":
model = TabTransformer(
column_idx={k: v for v, k in enumerate(colnames)},
embed_input=embed_input,
continuous_cols=colnames[n_cols:],
transformer_activation=activation,
)
elif model_name == "saint":
model = SAINT(
column_idx={k: v for v, k in enumerate(colnames)},
embed_input=embed_input,
continuous_cols=colnames[n_cols:],
transformer_activation=activation,
)
elif model_name == "tabperceiver":
model = TabPerceiver(
column_idx={k: v for v, k in enumerate(colnames)},
embed_input=embed_input,
continuous_cols=colnames[n_cols:],
transformer_activation=activation,
n_latents=2,
latent_dim=16,
n_perceiver_blocks=2,
share_weights=False,
)
params = {
"column_idx": {k: v for v, k in enumerate(colnames)},
"embed_input": embed_input,
"continuous_cols": colnames[n_cols:],
"transformer_activation": activation,
}
model = _build_model(model_name, params)
out = model(X_tab)
assert out.size(0) == 10
......@@ -314,56 +294,32 @@ def test_transformer_activations(activation, model_name):
[
"tabtransformer",
"saint",
"tabperceiver",
"fttransformer",
"tabfastformer",
"tabperceiver",
],
)
def test_transformers_keep_attn(model_name):
if model_name == "tabtransformer":
model = TabTransformer(
column_idx={k: v for v, k in enumerate(colnames)},
embed_input=embed_input,
continuous_cols=colnames[n_cols:],
n_blocks=2,
)
elif model_name == "saint":
model = SAINT(
column_idx={k: v for v, k in enumerate(colnames)},
embed_input=embed_input,
continuous_cols=colnames[n_cols:],
embed_continuous=False,
n_blocks=2,
)
elif model_name == "tabperceiver":
model = TabPerceiver(
column_idx={k: v for v, k in enumerate(colnames)},
embed_input=embed_input,
continuous_cols=colnames[n_cols:],
n_latents=2,
latent_dim=16,
n_perceiver_blocks=2,
share_weights=False,
)
elif model_name == "tabfastformer":
model = TabFastFormer(
column_idx={k: v for v, k in enumerate(colnames)},
embed_input=embed_input,
continuous_cols=colnames[n_cols:],
embed_continuous=False,
n_blocks=2,
n_heads=4,
share_qv_weights=False,
share_weights=False,
)
params = {
"column_idx": {k: v for v, k in enumerate(colnames)},
"embed_input": embed_input,
"continuous_cols": colnames[n_cols:],
}
# n_cols is an unfortunate name I might change in the future. It refers to
# the number of cat and cont cols, so the total number of cols is
# n_cols * 2
total_n_cols = n_cols * 2
model = _build_model(model_name, params)
out = model(X_tab)
res = [out.size(0) == 10]
if model_name != "tabperceiver":
res.append(out.size(1) == model._set_mlp_hidden_dims()[-1])
res.append(len(model.attention_weights) == model.n_blocks)
else:
res.append(out.size(1) == model.mlp_hidden_dims[-1])
res.append(len(model.attention_weights) == model.n_perceiver_blocks)
if model_name == "tabtransformer":
......@@ -374,11 +330,16 @@ def test_transformers_keep_attn(model_name):
elif model_name == "saint":
res.append(
list(model.attention_weights[0][0].shape)
== [10, model.n_heads, n_cols, n_cols]
== [10, model.n_heads, total_n_cols, total_n_cols]
)
res.append(
list(model.attention_weights[0][1].shape)
== [1, model.n_heads, n_cols * n_embed, n_cols * n_embed]
== [1, model.n_heads, X_tab.shape[0], X_tab.shape[0]]
)
if model_name == "fttransformer":
res.append(
list(model.attention_weights[0].shape)
== [10, model.n_heads, total_n_cols, model.dim_k]
)
elif model_name == "tabperceiver":
res.append(
......@@ -395,9 +356,11 @@ def test_transformers_keep_attn(model_name):
)
elif model_name == "tabfastformer":
res.append(
list(model.attention_weights[0][0].shape) == [10, model.n_heads, n_cols]
list(model.attention_weights[0][0].shape)
== [10, model.n_heads, total_n_cols]
)
res.append(
list(model.attention_weights[0][1].shape) == [10, model.n_heads, n_cols]
list(model.attention_weights[0][1].shape)
== [10, model.n_heads, total_n_cols]
)
assert all(res)
......@@ -13,6 +13,7 @@ from pytorch_widedeep.models import (
WideDeep,
TabResnet,
TabPerceiver,
FTTransformer,
TabFastFormer,
TabTransformer,
)
......@@ -78,6 +79,33 @@ def test_non_transformer_models(deeptabular):
assert X_vec.shape[1] == embed_dim + cont_dim
###############################################################################
# Test Transformer models
###############################################################################
def _build_model(model_name, params):
if model_name == "tabtransformer":
return TabTransformer(input_dim=8, n_heads=2, n_blocks=2, **params)
if model_name == "saint":
return SAINT(input_dim=8, n_heads=2, n_blocks=2, **params)
if model_name == "fttransformer":
return FTTransformer(n_blocks=2, n_heads=2, dim_k=2, **params)
if model_name == "tabfastformer":
return TabFastFormer(n_blocks=2, n_heads=2, **params)
if model_name == "tabperceiver":
return TabPerceiver(
input_dim=8,
n_cross_attn_heads=2,
n_latents=2,
latent_dim=8,
n_latent_heads=2,
n_perceiver_blocks=2,
share_weights=False,
**params
)
@pytest.mark.parametrize(
"model_name, with_cls_token, share_embeddings, embed_continuous",
[
......@@ -85,26 +113,9 @@ def test_non_transformer_models(deeptabular):
("tabtransformer", True, False, False),
("tabtransformer", False, True, False),
("tabtransformer", True, False, True),
("saint", False, False, True),
("saint", True, False, True),
("saint", False, True, True),
("saint", True, False, True),
(
"tabperceiver",
False,
False,
True,
), # embed_continuous is irrelevant for the perceiver
("tabperceiver", True, False, True),
("tabperceiver", False, True, True),
("tabperceiver", True, False, True),
("tabfastformer", False, False, True),
("tabfastformer", True, False, True),
("tabfastformer", False, True, True),
("tabfastformer", True, False, True),
],
)
def test_transformer_models(
def test_tab_transformer_models(
model_name, with_cls_token, share_embeddings, embed_continuous
):
......@@ -120,50 +131,14 @@ def test_transformer_models(
)
X_tab = tab_preprocessor.fit_transform(df_init) # noqa: F841
if model_name == "tabtransformer":
deeptabular = TabTransformer(
column_idx=tab_preprocessor.column_idx,
embed_input=tab_preprocessor.embeddings_input,
continuous_cols=tab_preprocessor.continuous_cols,
embed_continuous=embed_continuous,
input_dim=8,
n_heads=2,
n_blocks=2,
)
elif model_name == "saint":
deeptabular = SAINT(
column_idx=tab_preprocessor.column_idx,
embed_input=tab_preprocessor.embeddings_input,
continuous_cols=tab_preprocessor.continuous_cols,
embed_continuous=True,
input_dim=8,
n_heads=2,
n_blocks=2,
)
elif model_name == "tabperceiver":
deeptabular = TabPerceiver(
column_idx=tab_preprocessor.column_idx,
embed_input=tab_preprocessor.embeddings_input,
continuous_cols=tab_preprocessor.continuous_cols,
input_dim=8,
n_cross_attn_heads=2,
n_latents=2,
latent_dim=8,
n_latent_heads=2,
n_perceiver_blocks=2,
share_weights=False,
)
elif model_name == "tabfastformer":
deeptabular = TabFastFormer(
column_idx=tab_preprocessor.column_idx,
embed_input=tab_preprocessor.embeddings_input,
continuous_cols=tab_preprocessor.continuous_cols,
embed_continuous=embed_continuous,
n_blocks=2,
n_heads=4,
share_qv_weights=False,
share_weights=False,
)
params = {
"column_idx": tab_preprocessor.column_idx,
"embed_input": tab_preprocessor.embeddings_input,
"continuous_cols": tab_preprocessor.continuous_cols,
"embed_continuous": embed_continuous,
}
deeptabular = _build_model(model_name, params)
# Let's assume the model is trained
model = WideDeep(deeptabular=deeptabular)
......@@ -176,3 +151,55 @@ def test_transformer_models(
out_dim = len(embed_cols) * deeptabular.input_dim + len(cont_cols)
assert X_vec.shape[1] == out_dim
@pytest.mark.parametrize(
"model_name, with_cls_token, share_embeddings",
[
("saint", False, True),
("saint", True, True),
("saint", False, False),
("fttransformer", False, True),
("fttransformer", True, True),
("fttransformer", False, False),
("tabfastformer", False, True),
("tabfastformer", True, True),
("tabfastformer", False, False),
(
"tabperceiver",
False,
True,
), # for the perceiver we do not need with_cls_token
("tabperceiver", False, False),
],
)
def test_transformer_family_models(model_name, with_cls_token, share_embeddings):
embed_cols = ["a", "b"]
cont_cols = ["c", "d"]
tab_preprocessor = TabPreprocessor(
embed_cols=embed_cols,
continuous_cols=cont_cols,
for_transformer=True,
with_cls_token=with_cls_token,
shared_embed=share_embeddings,
)
X_tab = tab_preprocessor.fit_transform(df_init) # noqa: F841
params = {
"column_idx": tab_preprocessor.column_idx,
"embed_input": tab_preprocessor.embeddings_input,
"continuous_cols": tab_preprocessor.continuous_cols,
}
deeptabular = _build_model(model_name, params)
# Let's assume the model is trained
model = WideDeep(deeptabular=deeptabular)
t2v = Tab2Vec(model, tab_preprocessor)
X_vec = t2v.transform(df_t2v)
out_dim = (len(embed_cols) + len(cont_cols)) * deeptabular.input_dim
assert X_vec.shape[1] == out_dim