diff --git a/tests/test_metrics/test_torchmetrics.py b/tests/test_metrics/test_torchmetrics.py index 545cd4885bebd62d420e32ce1153565861afca44..063d699f9cc4a5ae548ce8f8524dba5029fe375e 100644 --- a/tests/test_metrics/test_torchmetrics.py +++ b/tests/test_metrics/test_torchmetrics.py @@ -34,7 +34,7 @@ y_pred_bin_pt = torch.from_numpy(y_pred_bin_np) ("Precision", precision_score, Precision(task="binary")), ("Recall", recall_score, Recall(task="binary")), ("F1Score", f1_score, F1Score(task="binary")), - ("FBetaScore", f2_score_bin, FBetaScore(task="binary", beta=2.0)), + ("FBetaScore", f2_score_bin, FBetaScore(task="binary", beta=2.0)), ], ) def test_binary_metrics(metric_name, sklearn_metric, torch_metric): @@ -77,6 +77,10 @@ def f2_score_multi(y_true, y_pred, average): @pytest.mark.parametrize( "metric_name, sklearn_metric, torch_metric", [ + ("Accuracy", accuracy_score, Accuracy(task="multiclass", num_classes=3, average="micro")), + ("Precision", precision_score, Precision(task="multiclass", num_classes=3, average="macro")), + ("Recall", recall_score, Recall(task="multiclass", num_classes=3, average="macro")), + ("F1Score", f1_score, F1Score(task="multiclass", num_classes=3, average="macro")), ( "Accuracy", accuracy_score,