提交 f1926c6b 编写于 作者: P Pavol Mulinka

changed multiregression to multilabel for future use of multilabel task

上级 34ad5576
......@@ -57,6 +57,7 @@ class QuantileLoss(nn.Module):
assert input.shape == torch.Size(
[target.shape[0], len(self.quantiles)]
), f"Wrong shape of input, pred_dim of the model that is using QuantileLoss must be equal to number of quantiles, i.e. {len(self.quantiles)}."
target = target.view(-1, 1).float()
losses = []
for i, q in enumerate(self.quantiles):
errors = target - input[..., i]
......@@ -80,7 +81,7 @@ class ZILNLoss(nn.Module):
Parameters
----------
input: Tensor
input tensor with predictions (not probabilities)
input tensor with predictions (not probabilities) with spape (N,3), where N is the batch size
target: Tensor
target tensor with the actual classes
......
......@@ -75,7 +75,7 @@ class _ObjectiveToMethod:
"zero_inflated_lognormal": "regression",
"ziln": "regression",
"tweedie": "regression",
"quantile": "multiregression",
"quantile": "multilabel",
}
@classproperty
......
......@@ -97,7 +97,7 @@ class Trainer:
folder in the repo.
.. note:: If ``custom_loss_function`` is not None, ``objective`` must be
'binary', 'multiclass', 'multiregression' or 'regression', consistent with the loss
'binary', 'multiclass', 'multilabel' or 'regression', consistent with the loss
function
optimizers: ``Optimzer`` or dict, optional, default= None
......@@ -265,11 +265,11 @@ class Trainer:
"binary",
"multiclass",
"regression",
"multiregression",
"multilabel",
]:
raise ValueError(
"If 'custom_loss_function' is not None, 'objective' must be 'binary' "
"'multiclass', 'regression' or 'multiregression', consistent with the loss function"
"'multiclass', 'regression' or 'multilabel', consistent with the loss function"
)
self.reducelronplateau = False
......@@ -705,7 +705,7 @@ class Trainer:
if self.method == "binary":
preds = np.vstack(preds_l).squeeze(1)
return (preds > 0.5).astype("int")
if self.method == "multiregression":
if self.method == "multilabel":
return np.vstack(preds_l)
if self.method == "multiclass":
preds = np.vstack(preds_l)
......@@ -802,9 +802,9 @@ class Trainer:
preds.std(axis=0),
)
).T
if self.method == "multiregression":
if self.method == "multilabel":
raise ValueError(
"Currently predict_uncertainty is not supported for multiregression method"
"Currently predict_uncertainty is not supported for multilabel method"
)
if self.method == "binary":
preds = preds.squeeze(1)
......@@ -1154,7 +1154,11 @@ class Trainer:
def _train_step(self, data: Dict[str, Tensor], target: Tensor, batch_idx: int):
self.model.train()
X = {k: v.cuda() for k, v in data.items()} if use_cuda else data
y = target.view(-1, 1).float() if self.method != "multiclass" else target
y = (
target.view(-1, 1).float()
if self.method not in ["multiclass", "multilabel"]
else target
)
y = y.to(device)
self.optimizer.zero_grad()
......@@ -1179,7 +1183,11 @@ class Trainer:
self.model.eval()
with torch.no_grad():
X = {k: v.cuda() for k, v in data.items()} if use_cuda else data
y = target.view(-1, 1).float() if self.method != "multiclass" else target
y = (
target.view(-1, 1).float()
if self.method not in ["multiclass", "multilabel"]
else target
)
y = y.to(device)
y_pred = self.model(X)
......@@ -1201,7 +1209,7 @@ class Trainer:
score = self.metric(y_pred, y)
if self.method == "binary":
score = self.metric(torch.sigmoid(y_pred), y)
if self.method == "multiregression":
if self.method == "multilabel":
score = self.metric(y_pred, y)
if self.method == "multiclass":
score = self.metric(F.softmax(y_pred, dim=1), y)
......@@ -1318,7 +1326,7 @@ class Trainer:
if custom_loss_function is not None:
return custom_loss_function
elif (
self.method not in ["regression", "multiregression"]
self.method not in ["regression", "multilabel"]
and "focal_loss" not in objective
):
return alias_to_loss(objective, weight=class_weight)
......
......@@ -161,7 +161,7 @@ method_to_objec = {
"ziln",
"tweedie",
],
"multiregression": [
"multilabel": [
"quantile",
],
}
......@@ -230,7 +230,7 @@ method_to_objec = {
False,
),
(X_wide, X_tab, target_regres, "regression", "ziln", 3, 1, False),
(X_wide, X_tab, target_regres, "multiregression", "quantile", 7, 7, False),
(X_wide, X_tab, target_regres, "multilabel", "quantile", 7, 7, False),
(X_wide, X_tab, target_regres, "regression", "tweedie", 1, 1, True),
(X_wide, X_tab, target_binary, "binary", "binary", 1, 2, False),
(X_wide, X_tab, target_binary, "binary", "logistic", 1, 2, False),
......@@ -292,7 +292,7 @@ def test_all_possible_objectives(
if method == "regression":
preds = trainer.predict(X_wide=X_wide, X_tab=X_tab)
out.append(preds.ndim == probs_dim)
elif method == "multiregression":
elif method == "multilabel":
preds = trainer.predict(X_wide=X_wide, X_tab=X_tab)
out.append(preds.shape[1] == probs_dim)
else:
......@@ -323,6 +323,6 @@ def test_inverse_maps():
out.append(
"zero_inflated_lognormal" in _ObjectiveToMethod.method_to_objecive["regression"]
)
out.append("quantile" in _ObjectiveToMethod.method_to_objecive["multiregression"])
out.append("quantile" in _ObjectiveToMethod.method_to_objecive["multilabel"])
out.append("tweedie" in _ObjectiveToMethod.method_to_objecive["regression"])
assert all(out)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册