提交 feb9130a 编写于 作者: J Javier Rodriguez Zaurin

re-wrote a test and fixed a couple of typing errors

上级 1767d7e0
......@@ -124,7 +124,7 @@ class ContrastiveDenoisingModel(nn.Module):
) -> Tuple[Union[MLP, nn.Identity], Union[MLP, nn.Identity]]:
if self.projection_head1_dims is not None:
projection_head1 = MLP(
projection_head1: Union[MLP, nn.Identity] = MLP(
d_hidden=self.projection_head1_dims,
activation=self.projection_heads_activation,
dropout=0.0,
......@@ -133,7 +133,7 @@ class ContrastiveDenoisingModel(nn.Module):
linear_first=False,
)
if self.projection_head2_dims is not None:
projection_head2 = MLP(
projection_head2: Union[MLP, nn.Identity] = MLP(
d_hidden=self.projection_head2_dims,
activation=self.projection_heads_activation,
dropout=0.0,
......@@ -155,7 +155,7 @@ class ContrastiveDenoisingModel(nn.Module):
if self.use_cat_mlp:
if self.cat_mlp_type == "single":
denoise_cat_mlp = CatSingleMlp(
denoise_cat_mlp: Union[CatSingleMlp, CatMlpPerFeature] = CatSingleMlp(
self.model.input_dim,
self.model.cat_embed_input,
self.model.column_idx,
......@@ -173,7 +173,9 @@ class ContrastiveDenoisingModel(nn.Module):
if self.model.continuous_cols is not None:
if self.cont_mlp_type == "single":
denoise_cont_mlp = ContSingleMlp(
denoise_cont_mlp: Union[
ContSingleMlp, ContMlpPerFeature
] = ContSingleMlp(
self.model.input_dim,
self.model.continuous_cols,
self.model.column_idx,
......@@ -191,12 +193,10 @@ class ContrastiveDenoisingModel(nn.Module):
return denoise_cat_mlp, denoise_cont_mlp
def _set_idx_to_substract(
self, encoding_dict
) -> Optional[Dict[str, Dict[str, int]]]:
def _set_idx_to_substract(self, encoding_dict) -> Optional[Dict[str, int]]:
if self.cat_mlp_type == "multiple":
idx_to_substract: Optional[Dict[str, Dict[str, int]]] = {
idx_to_substract: Optional[Dict[str, int]] = {
k: min(sorted(list(v.values()))) for k, v in encoding_dict.items()
}
......
......@@ -330,7 +330,7 @@ class TabNetDecoder(nn.Module):
initialize_non_glu(self.reconstruction_layer, step_dim, embed_dim)
def forward(self, X: List[Tensor]) -> Tensor:
out = 0.0
out = torch.tensor(0.0)
for i, x in enumerate(X):
x = self.decoder[i](x)
out = torch.add(out, x)
......
......@@ -125,7 +125,7 @@ test_df = pd.DataFrame(
"with_cls_token",
[True, False],
)
def test_cont_den_different_setups( # noqa: C901
def test_cont_den_multiple_mlps_different_setups(
transf_model, cat_or_cont, with_cls_token
):
......@@ -147,6 +147,65 @@ def test_cont_den_different_setups( # noqa: C901
else None
)
tr_model = _build_transf_model(
transf_model, preprocessor, cat_embed_input, continuous_cols
)
cd_model = ContrastiveDenoisingModel(
tr_model,
preprocessor,
loss_type="both",
projection_head1_dims=None,
projection_head2_dims=None,
projection_heads_activation="relu",
cat_mlp_type="multiple",
cont_mlp_type="multiple",
denoise_mlps_activation="relu",
)
if cat_or_cont in ["cat", "both"]:
out_dim = []
for name, param in cd_model.denoise_cat_mlp.named_parameters():
if ("dense_layer_1.0" in name) and ("weight" in name):
out_dim.append(param.shape[0])
g_projs, cat_x_and_x_, cont_x_and_x_ = cd_model(X)
checks = []
if g_projs is not None:
projs_check = _check_g_projs(X, g_projs, tr_model, with_cls_token)
checks.extend([projs_check])
if cat_x_and_x_ is not None:
cat_checks = _check_cat_multiple_denoise_mlps(
X, cat_x_and_x_, with_cls_token, out_dim
)
checks.extend([cat_checks])
if cat_or_cont == "both":
cont_if_cat_check = _check_cont_if_cat_multiple_denoise_mlps(
X, cont_x_and_x_, with_cls_token
)
checks.extend([cont_if_cat_check])
elif cat_or_cont == "cont":
cont_only_check = _check_cont_only_multiple_denoise_mlps(
X, cont_x_and_x_, with_cls_token
)
checks.extend([cont_only_check])
assert all(checks)
def _build_transf_model(transf_model, preprocessor, cat_embed_input, continuous_cols):
if transf_model == "tabtransformer":
model = TabTransformer(
column_idx=preprocessor.column_idx,
......@@ -181,70 +240,70 @@ def test_cont_den_different_setups( # noqa: C901
n_heads=2,
)
cd_model = ContrastiveDenoisingModel(
model,
preprocessor,
loss_type="both",
projection_head1_dims=None,
projection_head2_dims=None,
projection_heads_activation="relu",
cat_mlp_type="multiple",
cont_mlp_type="multiple",
denoise_mlps_activation="relu",
)
return model
if cat_or_cont in ["cat", "both"]:
out_dim = []
for name, param in cd_model.denoise_cat_mlp.named_parameters():
if ("dense_layer_1.0" in name) and ("weight" in name):
out_dim.append(param.shape[0])
g_projs, cat_x_and_x_, cont_x_and_x_ = cd_model(X)
def _check_g_projs(X, g_projs, model, with_cls_token):
assertions = []
if g_projs is not None:
asrt1 = g_projs[0].shape[1] == X.shape[1] - 1 if with_cls_token else X.shape[1]
asrt2 = g_projs[1].shape[1] == X.shape[1] - 1 if with_cls_token else X.shape[1]
asrt3 = g_projs[0].shape[2] == model.input_dim
asrt4 = g_projs[1].shape[2] == model.input_dim
assertions.extend([asrt1, asrt2, asrt3, asrt4])
asrt1 = g_projs[0].shape[1] == X.shape[1] - 1 if with_cls_token else X.shape[1]
asrt2 = g_projs[1].shape[1] == X.shape[1] - 1 if with_cls_token else X.shape[1]
asrt3 = g_projs[0].shape[2] == model.input_dim
asrt4 = g_projs[1].shape[2] == model.input_dim
if cat_x_and_x_ is not None:
assertions.extend([asrt1, asrt2, asrt3, asrt4])
targ1 = (X[:, 1] - 2).long() if with_cls_token else (X[:, 0] - 1).long()
idx_to_substract_col2 = min(X[:, 2]) if with_cls_token else min(X[:, 1])
targ2 = (
(X[:, 2] - idx_to_substract_col2).long()
if with_cls_token
else (X[:, 1] - idx_to_substract_col2).long()
)
return all(assertions)
assrt5 = all(cat_x_and_x_[0][0] == targ1)
assrt6 = all(cat_x_and_x_[1][0] == targ2)
assrt7 = cat_x_and_x_[0][1].shape[1] == out_dim[0]
assrt8 = cat_x_and_x_[1][1].shape[1] == out_dim[1]
assertions.extend([assrt5, assrt6, assrt7, assrt8])
def _check_cat_multiple_denoise_mlps(X, cat_x_and_x_, with_cls_token, out_dim):
if cat_or_cont == "both":
assertions = []
targ1 = X[:, 3] if with_cls_token else X[:, 2]
targ2 = X[:, 4] if with_cls_token else X[:, 3]
targ1 = (X[:, 1] - 2).long() if with_cls_token else (X[:, 0] - 1).long()
idx_to_substract_col2 = min(X[:, 2]) if with_cls_token else min(X[:, 1])
targ2 = (
(X[:, 2] - idx_to_substract_col2).long()
if with_cls_token
else (X[:, 1] - idx_to_substract_col2).long()
)
assrt9 = all(torch.isclose(cont_x_and_x_[0][0].squeeze(1), targ1.float()))
assrt10 = all(torch.isclose(cont_x_and_x_[1][0].squeeze(1), targ2.float()))
assrt1 = all(cat_x_and_x_[0][0] == targ1)
assrt2 = all(cat_x_and_x_[1][0] == targ2)
assrt3 = cat_x_and_x_[0][1].shape[1] == out_dim[0]
assrt4 = cat_x_and_x_[1][1].shape[1] == out_dim[1]
assertions.extend([assrt9, assrt10])
assertions.extend([assrt1, assrt2, assrt3, assrt4])
elif cat_or_cont == "cont":
return assertions
def _check_cont_if_cat_multiple_denoise_mlps(X, cont_x_and_x_, with_cls_token):
assertions = []
targ1 = X[:, 3] if with_cls_token else X[:, 2]
targ2 = X[:, 4] if with_cls_token else X[:, 3]
assrt1 = all(torch.isclose(cont_x_and_x_[0][0].squeeze(1), targ1.float()))
assrt2 = all(torch.isclose(cont_x_and_x_[1][0].squeeze(1), targ2.float()))
assertions.extend([assrt1, assrt2])
return assertions
def _check_cont_only_multiple_denoise_mlps(X, cont_x_and_x_, with_cls_token):
assertions = []
targ1 = X[:, 1] if with_cls_token else X[:, 0]
targ2 = X[:, 2] if with_cls_token else X[:, 1]
targ1 = X[:, 1] if with_cls_token else X[:, 0]
targ2 = X[:, 2] if with_cls_token else X[:, 1]
assrt9 = all(torch.isclose(cont_x_and_x_[0][0].squeeze(1), targ1.float()))
assrt10 = all(torch.isclose(cont_x_and_x_[1][0].squeeze(1), targ2.float()))
assrt1 = all(torch.isclose(cont_x_and_x_[0][0].squeeze(1), targ1.float()))
assrt2 = all(torch.isclose(cont_x_and_x_[1][0].squeeze(1), targ2.float()))
assertions.extend([assrt9, assrt10])
assertions.extend([assrt1, assrt2])
assert all(assertions)
return assertions
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册