提交 2dcd6886 编写于 作者: P Pavol Mulinka

fixed torchmetrics test

上级 f075e2c3
......@@ -30,10 +30,10 @@ y_pred_bin_pt = torch.from_numpy(y_pred_bin_np)
@pytest.mark.parametrize(
"metric_name, sklearn_metric, torch_metric",
[
("Accuracy", accuracy_score, Accuracy()),
("Precision", precision_score, Precision()),
("Recall", recall_score, Recall()),
("F1Score", f1_score, F1Score()),
("Accuracy", accuracy_score, Accuracy(task="binary")),
("Precision", precision_score, Precision(task="binary")),
("Recall", recall_score, Recall(task="binary")),
("F1Score", f1_score, F1Score(task="binary")),
("FBetaScore", f2_score_bin, FBetaScore(beta=2)),
],
)
......@@ -77,10 +77,10 @@ def f2_score_multi(y_true, y_pred, average):
@pytest.mark.parametrize(
"metric_name, sklearn_metric, torch_metric",
[
("Accuracy", accuracy_score, Accuracy(num_classes=3, average="micro")),
("Precision", precision_score, Precision(num_classes=3, average="macro")),
("Recall", recall_score, Recall(num_classes=3, average="macro")),
("F1Score", f1_score, F1Score(num_classes=3, average="macro")),
("Accuracy", accuracy_score, Accuracy(task="multiclass", num_classes=3, average="micro")),
("Precision", precision_score, Precision(task="multiclass", num_classes=3, average="macro")),
("Recall", recall_score, Recall(task="multiclass", num_classes=3, average="macro")),
("F1Score", f1_score, F1Score(task="multiclass", num_classes=3, average="macro")),
(
"FBetaScore",
f2_score_multi,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册