提交 f1926c6b 编写于 作者: P Pavol Mulinka

changed multiregression to multilabel for future use of multilabel task

上级 34ad5576
...@@ -57,6 +57,7 @@ class QuantileLoss(nn.Module): ...@@ -57,6 +57,7 @@ class QuantileLoss(nn.Module):
assert input.shape == torch.Size( assert input.shape == torch.Size(
[target.shape[0], len(self.quantiles)] [target.shape[0], len(self.quantiles)]
), f"Wrong shape of input, pred_dim of the model that is using QuantileLoss must be equal to number of quantiles, i.e. {len(self.quantiles)}." ), f"Wrong shape of input, pred_dim of the model that is using QuantileLoss must be equal to number of quantiles, i.e. {len(self.quantiles)}."
target = target.view(-1, 1).float()
losses = [] losses = []
for i, q in enumerate(self.quantiles): for i, q in enumerate(self.quantiles):
errors = target - input[..., i] errors = target - input[..., i]
...@@ -80,7 +81,7 @@ class ZILNLoss(nn.Module): ...@@ -80,7 +81,7 @@ class ZILNLoss(nn.Module):
Parameters Parameters
---------- ----------
input: Tensor input: Tensor
input tensor with predictions (not probabilities) input tensor with predictions (not probabilities) with spape (N,3), where N is the batch size
target: Tensor target: Tensor
target tensor with the actual classes target tensor with the actual classes
......
...@@ -75,7 +75,7 @@ class _ObjectiveToMethod: ...@@ -75,7 +75,7 @@ class _ObjectiveToMethod:
"zero_inflated_lognormal": "regression", "zero_inflated_lognormal": "regression",
"ziln": "regression", "ziln": "regression",
"tweedie": "regression", "tweedie": "regression",
"quantile": "multiregression", "quantile": "multilabel",
} }
@classproperty @classproperty
......
...@@ -97,7 +97,7 @@ class Trainer: ...@@ -97,7 +97,7 @@ class Trainer:
folder in the repo. folder in the repo.
.. note:: If ``custom_loss_function`` is not None, ``objective`` must be .. note:: If ``custom_loss_function`` is not None, ``objective`` must be
'binary', 'multiclass', 'multiregression' or 'regression', consistent with the loss 'binary', 'multiclass', 'multilabel' or 'regression', consistent with the loss
function function
optimizers: ``Optimzer`` or dict, optional, default= None optimizers: ``Optimzer`` or dict, optional, default= None
...@@ -265,11 +265,11 @@ class Trainer: ...@@ -265,11 +265,11 @@ class Trainer:
"binary", "binary",
"multiclass", "multiclass",
"regression", "regression",
"multiregression", "multilabel",
]: ]:
raise ValueError( raise ValueError(
"If 'custom_loss_function' is not None, 'objective' must be 'binary' " "If 'custom_loss_function' is not None, 'objective' must be 'binary' "
"'multiclass', 'regression' or 'multiregression', consistent with the loss function" "'multiclass', 'regression' or 'multilabel', consistent with the loss function"
) )
self.reducelronplateau = False self.reducelronplateau = False
...@@ -705,7 +705,7 @@ class Trainer: ...@@ -705,7 +705,7 @@ class Trainer:
if self.method == "binary": if self.method == "binary":
preds = np.vstack(preds_l).squeeze(1) preds = np.vstack(preds_l).squeeze(1)
return (preds > 0.5).astype("int") return (preds > 0.5).astype("int")
if self.method == "multiregression": if self.method == "multilabel":
return np.vstack(preds_l) return np.vstack(preds_l)
if self.method == "multiclass": if self.method == "multiclass":
preds = np.vstack(preds_l) preds = np.vstack(preds_l)
...@@ -802,9 +802,9 @@ class Trainer: ...@@ -802,9 +802,9 @@ class Trainer:
preds.std(axis=0), preds.std(axis=0),
) )
).T ).T
if self.method == "multiregression": if self.method == "multilabel":
raise ValueError( raise ValueError(
"Currently predict_uncertainty is not supported for multiregression method" "Currently predict_uncertainty is not supported for multilabel method"
) )
if self.method == "binary": if self.method == "binary":
preds = preds.squeeze(1) preds = preds.squeeze(1)
...@@ -1154,7 +1154,11 @@ class Trainer: ...@@ -1154,7 +1154,11 @@ class Trainer:
def _train_step(self, data: Dict[str, Tensor], target: Tensor, batch_idx: int): def _train_step(self, data: Dict[str, Tensor], target: Tensor, batch_idx: int):
self.model.train() self.model.train()
X = {k: v.cuda() for k, v in data.items()} if use_cuda else data X = {k: v.cuda() for k, v in data.items()} if use_cuda else data
y = target.view(-1, 1).float() if self.method != "multiclass" else target y = (
target.view(-1, 1).float()
if self.method not in ["multiclass", "multilabel"]
else target
)
y = y.to(device) y = y.to(device)
self.optimizer.zero_grad() self.optimizer.zero_grad()
...@@ -1179,7 +1183,11 @@ class Trainer: ...@@ -1179,7 +1183,11 @@ class Trainer:
self.model.eval() self.model.eval()
with torch.no_grad(): with torch.no_grad():
X = {k: v.cuda() for k, v in data.items()} if use_cuda else data X = {k: v.cuda() for k, v in data.items()} if use_cuda else data
y = target.view(-1, 1).float() if self.method != "multiclass" else target y = (
target.view(-1, 1).float()
if self.method not in ["multiclass", "multilabel"]
else target
)
y = y.to(device) y = y.to(device)
y_pred = self.model(X) y_pred = self.model(X)
...@@ -1201,7 +1209,7 @@ class Trainer: ...@@ -1201,7 +1209,7 @@ class Trainer:
score = self.metric(y_pred, y) score = self.metric(y_pred, y)
if self.method == "binary": if self.method == "binary":
score = self.metric(torch.sigmoid(y_pred), y) score = self.metric(torch.sigmoid(y_pred), y)
if self.method == "multiregression": if self.method == "multilabel":
score = self.metric(y_pred, y) score = self.metric(y_pred, y)
if self.method == "multiclass": if self.method == "multiclass":
score = self.metric(F.softmax(y_pred, dim=1), y) score = self.metric(F.softmax(y_pred, dim=1), y)
...@@ -1318,7 +1326,7 @@ class Trainer: ...@@ -1318,7 +1326,7 @@ class Trainer:
if custom_loss_function is not None: if custom_loss_function is not None:
return custom_loss_function return custom_loss_function
elif ( elif (
self.method not in ["regression", "multiregression"] self.method not in ["regression", "multilabel"]
and "focal_loss" not in objective and "focal_loss" not in objective
): ):
return alias_to_loss(objective, weight=class_weight) return alias_to_loss(objective, weight=class_weight)
......
...@@ -161,7 +161,7 @@ method_to_objec = { ...@@ -161,7 +161,7 @@ method_to_objec = {
"ziln", "ziln",
"tweedie", "tweedie",
], ],
"multiregression": [ "multilabel": [
"quantile", "quantile",
], ],
} }
...@@ -230,7 +230,7 @@ method_to_objec = { ...@@ -230,7 +230,7 @@ method_to_objec = {
False, False,
), ),
(X_wide, X_tab, target_regres, "regression", "ziln", 3, 1, False), (X_wide, X_tab, target_regres, "regression", "ziln", 3, 1, False),
(X_wide, X_tab, target_regres, "multiregression", "quantile", 7, 7, False), (X_wide, X_tab, target_regres, "multilabel", "quantile", 7, 7, False),
(X_wide, X_tab, target_regres, "regression", "tweedie", 1, 1, True), (X_wide, X_tab, target_regres, "regression", "tweedie", 1, 1, True),
(X_wide, X_tab, target_binary, "binary", "binary", 1, 2, False), (X_wide, X_tab, target_binary, "binary", "binary", 1, 2, False),
(X_wide, X_tab, target_binary, "binary", "logistic", 1, 2, False), (X_wide, X_tab, target_binary, "binary", "logistic", 1, 2, False),
...@@ -292,7 +292,7 @@ def test_all_possible_objectives( ...@@ -292,7 +292,7 @@ def test_all_possible_objectives(
if method == "regression": if method == "regression":
preds = trainer.predict(X_wide=X_wide, X_tab=X_tab) preds = trainer.predict(X_wide=X_wide, X_tab=X_tab)
out.append(preds.ndim == probs_dim) out.append(preds.ndim == probs_dim)
elif method == "multiregression": elif method == "multilabel":
preds = trainer.predict(X_wide=X_wide, X_tab=X_tab) preds = trainer.predict(X_wide=X_wide, X_tab=X_tab)
out.append(preds.shape[1] == probs_dim) out.append(preds.shape[1] == probs_dim)
else: else:
...@@ -323,6 +323,6 @@ def test_inverse_maps(): ...@@ -323,6 +323,6 @@ def test_inverse_maps():
out.append( out.append(
"zero_inflated_lognormal" in _ObjectiveToMethod.method_to_objecive["regression"] "zero_inflated_lognormal" in _ObjectiveToMethod.method_to_objecive["regression"]
) )
out.append("quantile" in _ObjectiveToMethod.method_to_objecive["multiregression"]) out.append("quantile" in _ObjectiveToMethod.method_to_objecive["multilabel"])
out.append("tweedie" in _ObjectiveToMethod.method_to_objecive["regression"]) out.append("tweedie" in _ObjectiveToMethod.method_to_objecive["regression"])
assert all(out) assert all(out)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册