diff --git a/tests/test_metrics/test_torchmetrics.py b/tests/test_metrics/test_torchmetrics.py index 545cd4885bebd62d420e32ce1153565861afca44..8909a22aa20d9f2f05c61595e637257281d996e6 100644 --- a/tests/test_metrics/test_torchmetrics.py +++ b/tests/test_metrics/test_torchmetrics.py @@ -77,6 +77,26 @@ def f2_score_multi(y_true, y_pred, average): @pytest.mark.parametrize( "metric_name, sklearn_metric, torch_metric", [ + ( + "Accuracy", + accuracy_score, + Accuracy(task="multiclass", num_classes=3, average="micro"), + ), + ( + "Precision", + precision_score, + Precision(task="multiclass", num_classes=3, average="macro"), + ), + ( + "Recall", + recall_score, + Recall(task="multiclass", num_classes=3, average="macro"), + ), + ( + "F1Score", + f1_score, + F1Score(task="multiclass", num_classes=3, average="macro"), + ), ( "Accuracy", accuracy_score,