diff --git a/tests/test_metrics/test_torchmetrics.py b/tests/test_metrics/test_torchmetrics.py
index 8909a22aa20d9f2f05c61595e637257281d996e6..3213229ac53a2d296ec60bf2b1810504162e2571 100644
--- a/tests/test_metrics/test_torchmetrics.py
+++ b/tests/test_metrics/test_torchmetrics.py
@@ -30,18 +30,18 @@ y_pred_bin_pt = torch.from_numpy(y_pred_bin_np)
 @pytest.mark.parametrize(
     "metric_name, sklearn_metric, torch_metric",
     [
-        ("Accuracy", accuracy_score, Accuracy(task="binary")),
-        ("Precision", precision_score, Precision(task="binary")),
-        ("Recall", recall_score, Recall(task="binary")),
-        ("F1Score", f1_score, F1Score(task="binary")),
-        ("FBetaScore", f2_score_bin, FBetaScore(task="binary", beta=2.0)),
+        ("BinaryAccuracy", accuracy_score, Accuracy(task="binary")),
+        ("BinaryPrecision", precision_score, Precision(task="binary")),
+        ("BinaryRecall", recall_score, Recall(task="binary")),
+        ("BinaryF1Score", f1_score, F1Score(task="binary")),
+        ("BinaryFBetaScore", f2_score_bin, FBetaScore(task="binary", beta=2.0)),
     ],
 )
 def test_binary_metrics(metric_name, sklearn_metric, torch_metric):
     sk_res = sklearn_metric(y_true_bin_np, y_pred_bin_np.round())
     wd_metric = MultipleMetrics(metrics=[torch_metric])
     wd_logs = wd_metric(y_pred_bin_pt, y_true_bin_pt)
-    wd_res = wd_logs[f"Binary{metric_name}"]
+    wd_res = wd_logs[metric_name]
     if wd_res.size != 1:
         wd_res = wd_res[1]
     assert np.isclose(sk_res, wd_res)
@@ -77,46 +77,6 @@ def f2_score_multi(y_true, y_pred, average):
 @pytest.mark.parametrize(
     "metric_name, sklearn_metric, torch_metric",
     [
-        (
-            "Accuracy",
-            accuracy_score,
-            Accuracy(task="multiclass", num_classes=3, average="micro"),
-        ),
-        (
-            "Precision",
-            precision_score,
-            Precision(task="multiclass", num_classes=3, average="macro"),
-        ),
-        (
-            "Recall",
-            recall_score,
-            Recall(task="multiclass", num_classes=3, average="macro"),
-        ),
-        (
-            "F1Score",
-            f1_score,
-            F1Score(task="multiclass", num_classes=3, average="macro"),
-        ),
-        (
-            "Accuracy",
-            accuracy_score,
-            Accuracy(task="multiclass", num_classes=3, average="micro"),
-        ),
-        (
-            "Precision",
-            precision_score,
-            Precision(task="multiclass", num_classes=3, average="macro"),
-        ),
-        (
-            "Recall",
-            recall_score,
-            Recall(task="multiclass", num_classes=3, average="macro"),
-        ),
-        (
-            "F1Score",
-            f1_score,
-            F1Score(task="multiclass", num_classes=3, average="macro"),
-        ),
         (
             "MulticlassAccuracy",
             accuracy_score,
@@ -140,7 +100,7 @@ def f2_score_multi(y_true, y_pred, average):
         (
             "MulticlassFBetaScore",
             f2_score_multi,
-            FBetaScore(task="multiclass", beta=3.0, num_classes=3, average="macro"),
+            FBetaScore(beta=3.0, task="multiclass", num_classes=3, average="macro"),
         ),
     ],
 )
@@ -154,6 +114,6 @@ def test_muticlass_metrics(metric_name, sklearn_metric, torch_metric):
 
     wd_metric = MultipleMetrics(metrics=[torch_metric])
     wd_logs = wd_metric(y_pred_multi_pt, y_true_multi_pt)
-    wd_res = wd_logs[f"Multiclass{metric_name}"]
+    wd_res = wd_logs[metric_name]
 
     assert np.isclose(sk_res, wd_res, atol=0.01)