提交 fadede26 编写于 作者: J jrzaurin

Fixed issue #53 related to the use of some transformer models without categorical columns

上级 6540cd3c
......@@ -11,6 +11,7 @@
[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
[![Maintenance](https://img.shields.io/badge/Maintained%3F-yes-green.svg)](https://github.com/jrzaurin/pytorch-widedeep/graphs/commit-activity)
[![contributions welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg?style=flat)](https://github.com/jrzaurin/pytorch-widedeep/issues)
[![Slack](https://img.shields.io/badge/slack-chat-green.svg?logo=slack)](https://join.slack.com/t/pytorch-widedeep/shared_invite/zt-soss7stf-iXpVuLeKZz8lGTnxxtHtTw)
# pytorch-widedeep
......
1.0.9
\ No newline at end of file
1.0.10
\ No newline at end of file
......@@ -6,6 +6,7 @@
[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
[![Maintenance](https://img.shields.io/badge/Maintained%3F-yes-green.svg)](https://github.com/jrzaurin/pytorch-widedeep/graphs/commit-activity)
[![contributions welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg?style=flat)](https://github.com/jrzaurin/pytorch-widedeep/issues)
[![Slack](https://img.shields.io/badge/slack-chat-green.svg?logo=slack)](https://join.slack.com/t/pytorch-widedeep/shared_invite/zt-soss7stf-iXpVuLeKZz8lGTnxxtHtTw)
# pytorch-widedeep
......
......@@ -134,7 +134,7 @@ class FTTransformer(nn.Module):
def __init__(
self,
column_idx: Dict[str, int],
embed_input: List[Tuple[str, int]],
embed_input: Optional[List[Tuple[str, int]]] = None,
embed_dropout: float = 0.1,
full_embed_dropout: bool = False,
shared_embed: bool = False,
......@@ -194,11 +194,6 @@ class FTTransformer(nn.Module):
self.n_cont = len(continuous_cols) if continuous_cols is not None else 0
self.n_feats = self.n_cat + self.n_cont
if self.n_cont and not self.n_cat and not self.embed_continuous:
raise ValueError(
"If only continuous features are used 'embed_continuous' must be set to 'True'"
)
self.cat_and_cont_embed = CatAndContEmbeddings(
input_dim,
column_idx,
......
......@@ -120,7 +120,7 @@ class SAINT(nn.Module):
def __init__(
self,
column_idx: Dict[str, int],
embed_input: List[Tuple[str, int]],
embed_input: Optional[List[Tuple[str, int]]] = None,
embed_dropout: float = 0.1,
full_embed_dropout: bool = False,
shared_embed: bool = False,
......@@ -173,11 +173,6 @@ class SAINT(nn.Module):
self.n_cont = len(continuous_cols) if continuous_cols is not None else 0
self.n_feats = self.n_cat + self.n_cont
if self.n_cont and not self.n_cat and not self.embed_continuous:
raise ValueError(
"If only continuous features are used 'embed_continuous' must be set to 'True'"
)
self.cat_and_cont_embed = CatAndContEmbeddings(
input_dim,
column_idx,
......
......@@ -182,11 +182,6 @@ class TabFastFormer(nn.Module):
self.n_cont = len(continuous_cols) if continuous_cols is not None else 0
self.n_feats = self.n_cat + self.n_cont
if self.n_cont and not self.n_cat and not self.embed_continuous:
raise ValueError(
"If only continuous features are used 'embed_continuous' must be set to 'True'"
)
self.cat_and_cont_embed = CatAndContEmbeddings(
input_dim,
column_idx,
......
__version__ = "1.0.9"
__version__ = "1.0.10"
......@@ -449,3 +449,34 @@ def test_ft_transformer_mlp(mlp_first_h, shoud_work):
else:
with pytest.raises(AssertionError):
model = _build_model("fttransformer", params) # noqa: F841
###############################################################################
# Test transformers with only continuous cols
###############################################################################
# Fixture for the continuous-only regression tests (issue #53):
# 4 continuous feature columns, 10 samples each, stacked row-wise and then
# transposed to the (n_samples, n_features) layout the models expect.
# NOTE: the per-column `np.random.rand(10)` calls are kept as-is so the RNG
# draw order matches the original fixture exactly.
_cont_cols = [np.random.rand(10) for _ in range(4)]
X_tab_only_cont = torch.from_numpy(np.vstack(_cont_cols).T)
# Column names "a".."d", one per continuous feature.
colnames_only_cont = list(string.ascii_lowercase)[:4]
@pytest.mark.parametrize(
    "model_name",
    ["fttransformer", "saint", "tabfastformer"],
)
def test_transformers_only_cont(model_name):
    """Regression test for issue #53: transformer models must build and run
    when the input has continuous columns only (no categorical embeddings)."""
    col_idx = {col: idx for idx, col in enumerate(colnames_only_cont)}
    params = {
        "column_idx": col_idx,
        "continuous_cols": colnames_only_cont,
    }
    model = _build_model(model_name, params)
    out = model(X_tab_only_cont)
    # One output row per input sample, with the model's declared output width.
    assert out.size(0) == 10 and out.size(1) == model.output_dim
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册