提交 7f967740 编写于 作者: P Pavol Mulinka

fixed torchmetrics

上级 c4581337
......@@ -34,14 +34,14 @@ y_pred_bin_pt = torch.from_numpy(y_pred_bin_np)
("Precision", precision_score, Precision(task="binary")),
("Recall", recall_score, Recall(task="binary")),
("F1Score", f1_score, F1Score(task="binary")),
("FBetaScore", f2_score_bin, FBetaScore(task="binary", beta=2)),
("FBetaScore", f2_score_bin, FBetaScore(task="binary", beta=2.0)),
],
)
def test_binary_metrics(metric_name, sklearn_metric, torch_metric):
sk_res = sklearn_metric(y_true_bin_np, y_pred_bin_np.round())
wd_metric = MultipleMetrics(metrics=[torch_metric])
wd_logs = wd_metric(y_pred_bin_pt, y_true_bin_pt)
wd_res = wd_logs[metric_name]
wd_res = wd_logs[f"Binary{metric_name}"]
if wd_res.size != 1:
wd_res = wd_res[1]
assert np.isclose(sk_res, wd_res)
......@@ -120,7 +120,7 @@ def f2_score_multi(y_true, y_pred, average):
(
"MulticlassFBetaScore",
f2_score_multi,
FBetaScore(task="multiclass", beta=3, num_classes=3, average="macro"),
FBetaScore(task="multiclass", beta=3.0, num_classes=3, average="macro"),
),
],
)
......@@ -134,6 +134,6 @@ def test_muticlass_metrics(metric_name, sklearn_metric, torch_metric):
wd_metric = MultipleMetrics(metrics=[torch_metric])
wd_logs = wd_metric(y_pred_multi_pt, y_true_multi_pt)
wd_res = wd_logs[metric_name]
wd_res = wd_logs[f"Multiclass{metric_name}"]
assert np.isclose(sk_res, wd_res, atol=0.01)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册