提交 fd3e4e34 编写于 作者: J jrzaurin

Added tests for two minor changes in the TabPreprocessor and the Trainer

上级 cee8e1c7
......@@ -295,7 +295,9 @@ class TabPreprocessor(BasePreprocessor):
and self.continuous_cols is not None
and len(np.intersect1d(self.cat_embed_cols, self.continuous_cols)) > 0
):
overlapping_cols = list(np.intersect1d(cat_embed_cols, continuous_cols))
overlapping_cols = list(
np.intersect1d(self.cat_embed_cols, self.continuous_cols)
)
raise ValueError(
"Currently passing columns as both categorical and continuum is not supported."
" Please, choose one or the other for the following columns: {}".format(
......
......@@ -294,3 +294,19 @@ def test_embed_sz_rule_of_thumb(rule):
tab_preprocessor.embed_dim[col] == embed_szs[col] for col in embed_szs.keys()
]
assert all(out)
###############################################################################
# Test that a ValueError is raised for columns passed as both categorical
# and continuous
###############################################################################
def test_overlapping_cols_valueerror():
    """TabPreprocessor must reject columns listed in both cat and cont sets."""
    shared_cols = ["col1", "col2"]
    with pytest.raises(ValueError):
        # instantiation itself should raise; the object is never used
        TabPreprocessor(  # noqa: F841
            cat_embed_cols=shared_cols, continuous_cols=shared_cols
        )
......@@ -300,3 +300,23 @@ def test_custom_dataloader():
)
# simply checking that runs with DataLoaderImbalanced
assert "train_loss" in trainer.history.keys()
##############################################################################
# Test that a ValueError is raised for an inconsistent multiclass setup
##############################################################################
def test_multiclass_warning():
    """Trainer must raise ValueError when built with loss='multiclass' here."""
    # Build the model inline: a linear wide part plus a small TabMlp tower.
    model = WideDeep(
        wide=Wide(np.unique(X_wide).shape[0], 1),
        deeptabular=TabMlp(
            column_idx=column_idx,
            cat_embed_input=embed_input,
            continuous_cols=colnames[-5:],
            mlp_hidden_dims=[32, 16],
            mlp_dropout=[0.5, 0.5],
        ),
    )
    with pytest.raises(ValueError):
        # construction itself should raise; the Trainer is never used
        Trainer(model, loss="multiclass", verbose=0)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册