From 2a6ceaaac11297593e86ad1f06b27efed9efa2cc Mon Sep 17 00:00:00 2001 From: Pavol Mulinka Date: Sun, 8 Jan 2023 18:13:28 +0100 Subject: [PATCH] fixed torchmetrics test --- tests/test_metrics/test_torchmetrics.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/test_metrics/test_torchmetrics.py b/tests/test_metrics/test_torchmetrics.py index 545cd48..063d699 100644 --- a/tests/test_metrics/test_torchmetrics.py +++ b/tests/test_metrics/test_torchmetrics.py @@ -34,7 +34,7 @@ y_pred_bin_pt = torch.from_numpy(y_pred_bin_np) ("Precision", precision_score, Precision(task="binary")), ("Recall", recall_score, Recall(task="binary")), ("F1Score", f1_score, F1Score(task="binary")), - ("FBetaScore", f2_score_bin, FBetaScore(task="binary", beta=2.0)), + ("FBetaScore", f2_score_bin, FBetaScore(beta=2)), ], ) def test_binary_metrics(metric_name, sklearn_metric, torch_metric): @@ -77,6 +77,10 @@ def f2_score_multi(y_true, y_pred, average): @pytest.mark.parametrize( "metric_name, sklearn_metric, torch_metric", [ + ("Accuracy", accuracy_score, Accuracy(task="multiclass", num_classes=3, average="micro")), + ("Precision", precision_score, Precision(task="multiclass", num_classes=3, average="macro")), + ("Recall", recall_score, Recall(task="multiclass", num_classes=3, average="macro")), + ("F1Score", f1_score, F1Score(task="multiclass", num_classes=3, average="macro")), ( "Accuracy", accuracy_score, -- GitLab