提交 2967a4e0 编写于 作者: P Pavol Mulinka

again torchmetrics

上级 e6ac2d35
......@@ -30,18 +30,18 @@ y_pred_bin_pt = torch.from_numpy(y_pred_bin_np)
@pytest.mark.parametrize(
"metric_name, sklearn_metric, torch_metric",
[
("Accuracy", accuracy_score, Accuracy(task="binary")),
("Precision", precision_score, Precision(task="binary")),
("Recall", recall_score, Recall(task="binary")),
("F1Score", f1_score, F1Score(task="binary")),
("FBetaScore", f2_score_bin, FBetaScore(task="binary", beta=2.0)),
("BinaryAccuracy", accuracy_score, Accuracy(task="binary")),
("BinaryPrecision", precision_score, Precision(task="binary")),
("BinaryRecall", recall_score, Recall(task="binary")),
("BinaryF1Score", f1_score, F1Score(task="binary")),
("BinaryFBetaScore", f2_score_bin, FBetaScore(task="binary", beta=2.0)),
],
)
def test_binary_metrics(metric_name, sklearn_metric, torch_metric):
sk_res = sklearn_metric(y_true_bin_np, y_pred_bin_np.round())
wd_metric = MultipleMetrics(metrics=[torch_metric])
wd_logs = wd_metric(y_pred_bin_pt, y_true_bin_pt)
wd_res = wd_logs[f"Binary{metric_name}"]
wd_res = wd_logs[metric_name]
if wd_res.size != 1:
wd_res = wd_res[1]
assert np.isclose(sk_res, wd_res)
......@@ -77,46 +77,6 @@ def f2_score_multi(y_true, y_pred, average):
@pytest.mark.parametrize(
"metric_name, sklearn_metric, torch_metric",
[
(
"Accuracy",
accuracy_score,
Accuracy(task="multiclass", num_classes=3, average="micro"),
),
(
"Precision",
precision_score,
Precision(task="multiclass", num_classes=3, average="macro"),
),
(
"Recall",
recall_score,
Recall(task="multiclass", num_classes=3, average="macro"),
),
(
"F1Score",
f1_score,
F1Score(task="multiclass", num_classes=3, average="macro"),
),
(
"Accuracy",
accuracy_score,
Accuracy(task="multiclass", num_classes=3, average="micro"),
),
(
"Precision",
precision_score,
Precision(task="multiclass", num_classes=3, average="macro"),
),
(
"Recall",
recall_score,
Recall(task="multiclass", num_classes=3, average="macro"),
),
(
"F1Score",
f1_score,
F1Score(task="multiclass", num_classes=3, average="macro"),
),
(
"MulticlassAccuracy",
accuracy_score,
......@@ -140,7 +100,7 @@ def f2_score_multi(y_true, y_pred, average):
(
"MulticlassFBetaScore",
f2_score_multi,
FBetaScore(task="multiclass", beta=3.0, num_classes=3, average="macro"),
FBetaScore(beta=3.0, task="multiclass", num_classes=3, average="macro"),
),
],
)
......@@ -154,6 +114,6 @@ def test_muticlass_metrics(metric_name, sklearn_metric, torch_metric):
wd_metric = MultipleMetrics(metrics=[torch_metric])
wd_logs = wd_metric(y_pred_multi_pt, y_true_multi_pt)
wd_res = wd_logs[f"Multiclass{metric_name}"]
wd_res = wd_logs[metric_name]
assert np.isclose(sk_res, wd_res, atol=0.01)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册