提交 2a6ceaaa 编写于 作者: P Pavol Mulinka

fixed torchmetrics test

上级 7f967740
...@@ -34,7 +34,7 @@ y_pred_bin_pt = torch.from_numpy(y_pred_bin_np) ...@@ -34,7 +34,7 @@ y_pred_bin_pt = torch.from_numpy(y_pred_bin_np)
("Precision", precision_score, Precision(task="binary")), ("Precision", precision_score, Precision(task="binary")),
("Recall", recall_score, Recall(task="binary")), ("Recall", recall_score, Recall(task="binary")),
("F1Score", f1_score, F1Score(task="binary")), ("F1Score", f1_score, F1Score(task="binary")),
("FBetaScore", f2_score_bin, FBetaScore(task="binary", beta=2.0)), ("FBetaScore", f2_score_bin, FBetaScore(beta=2)),
], ],
) )
def test_binary_metrics(metric_name, sklearn_metric, torch_metric): def test_binary_metrics(metric_name, sklearn_metric, torch_metric):
...@@ -77,6 +77,10 @@ def f2_score_multi(y_true, y_pred, average): ...@@ -77,6 +77,10 @@ def f2_score_multi(y_true, y_pred, average):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"metric_name, sklearn_metric, torch_metric", "metric_name, sklearn_metric, torch_metric",
[ [
("Accuracy", accuracy_score, Accuracy(task="multiclass", num_classes=3, average="micro")),
("Precision", precision_score, Precision(task="multiclass", num_classes=3, average="macro")),
("Recall", recall_score, Recall(task="multiclass", num_classes=3, average="macro")),
("F1Score", f1_score, F1Score(task="multiclass", num_classes=3, average="macro")),
( (
"Accuracy", "Accuracy",
accuracy_score, accuracy_score,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册