提交 2a6ceaaa 编写于 作者: P Pavol Mulinka

fixed torchmetrics test

上级 7f967740
......@@ -34,7 +34,7 @@ y_pred_bin_pt = torch.from_numpy(y_pred_bin_np)
("Precision", precision_score, Precision(task="binary")),
("Recall", recall_score, Recall(task="binary")),
("F1Score", f1_score, F1Score(task="binary")),
("FBetaScore", f2_score_bin, FBetaScore(task="binary", beta=2.0)),
("FBetaScore", f2_score_bin, FBetaScore(beta=2)),
],
)
def test_binary_metrics(metric_name, sklearn_metric, torch_metric):
......@@ -77,6 +77,10 @@ def f2_score_multi(y_true, y_pred, average):
@pytest.mark.parametrize(
"metric_name, sklearn_metric, torch_metric",
[
("Accuracy", accuracy_score, Accuracy(task="multiclass", num_classes=3, average="micro")),
("Precision", precision_score, Precision(task="multiclass", num_classes=3, average="macro")),
("Recall", recall_score, Recall(task="multiclass", num_classes=3, average="macro")),
("F1Score", f1_score, F1Score(task="multiclass", num_classes=3, average="macro")),
(
"Accuracy",
accuracy_score,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册