Commit d13e86f9 authored by Pavol Mulinka

fixed multiclass torchmetrics

Parent ef9ea277
import numpy as np
import torch
from torchmetrics import Metric as TorchMetric
from torchmetrics import AUC
from .wdtypes import * # noqa: F403
@@ -38,10 +39,23 @@ class MultipleMetrics(object):
if isinstance(metric, Metric):
    logs[self.prefix + metric._name] = metric(y_pred, y_true)
if isinstance(metric, TorchMetric):
    if not hasattr(metric, "num_classes"):
        raise ValueError(
            """TorchMetric does not have a num_classes attribute.
            Use a metric from this library, or extend the TorchMetric with a
            num_classes attribute; see the `examples
            <https://github.com/jrzaurin/pytorch-widedeep/tree/master/examples>`_
            """
        )
    if metric.num_classes == 2:
        if isinstance(metric, AUC):
            metric.update(torch.round(y_pred).int(), y_true.int())
        else:
            metric.update(y_pred, y_true.int())
    if metric.num_classes > 2:  # type: ignore[operator]
        if isinstance(metric, AUC):
            metric.update(torch.max(y_pred, dim=1).indices, y_true.int())  # type: ignore[attr-defined]
        else:
            metric.update(y_pred, y_true.int())  # type: ignore[attr-defined]
    logs[self.prefix + type(metric).__name__] = (
        metric.compute().detach().cpu().numpy()
    )
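The ValueError above tells users to attach a num_classes attribute to any stock torchmetrics metric before passing it in. A minimal sketch of that pattern (mirroring the tests at the bottom of this diff), assuming a torchmetrics version where AUC is importable from the package root:

import torch
from torchmetrics import AUC

from pytorch_widedeep.metrics import MultipleMetrics

# AUC() exposes no num_classes constructor argument, so the attribute is
# patched onto the instance before it is handed to MultipleMetrics.
auc = AUC()
auc.num_classes = 2

# MultipleMetrics now passes the hasattr(metric, "num_classes") check and,
# because num_classes == 2 and the metric is an AUC, will round the
# predictions before calling update().
metrics = MultipleMetrics(metrics=[auc])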
@@ -396,3 +410,62 @@ class R2Score(Metric):
        y_true_avg = self.y_true_sum / self.num_examples
        self.denominator += ((y_true - y_true_avg) ** 2).sum().item()
        return np.array((1 - (self.numerator / self.denominator)))
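For context, the R2Score fragment above implements the usual coefficient of determination, R^2 = 1 - SS_res / SS_tot. A quick numeric check of the formula with hypothetical values:

import numpy as np

y_true = np.array([3.0, -0.5, 2.0, 7.0])
y_pred = np.array([2.5, 0.0, 2.0, 8.0])

numerator = ((y_true - y_pred) ** 2).sum()           # residual sum of squares
denominator = ((y_true - y_true.mean()) ** 2).sum()  # total sum of squares
print(1 - numerator / denominator)                   # ~0.9486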
class Accuracy(Metric):
    r"""Class to calculate the accuracy for both binary and categorical problems

    Parameters
    ----------
    top_k: int, default = 1
        Accuracy will be computed using the top k most likely classes in
        multiclass problems

    Examples
    --------
    >>> import torch
    >>>
    >>> from pytorch_widedeep.metrics import Accuracy
    >>>
    >>> acc = Accuracy()
    >>> y_true = torch.tensor([0, 1, 0, 1]).view(-1, 1)
    >>> y_pred = torch.tensor([[0.3, 0.2, 0.6, 0.7]]).view(-1, 1)
    >>> acc(y_pred, y_true)
    array(0.5)
    >>>
    >>> acc = Accuracy(top_k=2)
    >>> y_true = torch.tensor([0, 1, 2])
    >>> y_pred = torch.tensor([[0.3, 0.5, 0.2], [0.1, 0.1, 0.8], [0.1, 0.5, 0.4]])
    >>> acc(y_pred, y_true)
    array(0.66666667)
    """

    def __init__(self, top_k: int = 1):
        super(Accuracy, self).__init__()

        self.top_k = top_k
        self.correct_count = 0
        self.total_count = 0
        self._name = "acc"

    def reset(self):
        """resets counters to 0"""
        self.correct_count = 0
        self.total_count = 0

    def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray:
        num_classes = y_pred.size(1)

        if num_classes == 1:
            y_pred = y_pred.round()
        elif num_classes > 1:
            y_pred = y_pred.topk(self.top_k, 1)[1]
            y_true = y_true.view(-1, 1).expand_as(y_pred)

        self.correct_count += y_pred.eq(y_true).sum().item()  # type: ignore[assignment]
        self.total_count += len(y_pred)
        accuracy = float(self.correct_count) / float(self.total_count)
        return np.array(accuracy)
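The top-k branch in __call__ counts a row as correct when its target index appears among the k highest scores. A small sketch, with hypothetical scores, of the expand-and-compare trick it relies on:

import torch

y_pred = torch.tensor([[0.5, 0.3, 0.2], [0.6, 0.3, 0.1], [0.2, 0.5, 0.3]])
y_true = torch.tensor([0, 1, 0])

top2 = y_pred.topk(2, 1)[1]                    # tensor([[0, 1], [0, 1], [1, 2]])
expanded = y_true.view(-1, 1).expand_as(top2)  # each target repeated across k columns
print(top2.eq(expanded).sum().item() / len(y_pred))  # rows 0 and 1 hit -> 0.666...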
@@ -147,10 +147,14 @@ class Trainer:
<https://github.com/jrzaurin/pytorch-widedeep/tree/master/examples>`__
folder in the repo
- List of objects of type :obj:`torchmetrics.Metric`. This can be any
metric from the torchmetrics library that has a num_classes attribute. `Examples
<https://torchmetrics.readthedocs.io/en/latest/references/modules.html#
classification-metrics>`_.
Objects of type :obj:`torchmetrics.Metric` can be extended with a num_classes
attribute so that they can be used with the Trainer object; see `examples
<https://github.com/jrzaurin/pytorch-widedeep/tree/master/examples>`__.
This can also be a custom metric as long as it is an object of
type :obj:`Metric`. See `the instructions
<https://torchmetrics.readthedocs.io/en/latest/>`_.
class_weight: float, List or Tuple, optional, default=None
- float indicating the weight of the minority class in binary classification
......
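The metrics argument documented above accepts both kinds of objects side by side. A minimal usage sketch, assuming the usual pytorch-widedeep Trainer signature and a hypothetical, already-built WideDeep `model`:

from torchmetrics import F1
from pytorch_widedeep import Trainer
from pytorch_widedeep.metrics import Accuracy

# `model` is a hypothetical WideDeep instance built elsewhere.
trainer = Trainer(
    model,
    objective="binary",
    metrics=[
        Accuracy(),                         # library metric (subclass of Metric)
        F1(num_classes=2, average="none"),  # torchmetrics metric carrying num_classes
    ],
)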
import numpy as np
import torch
import pytest
from torchmetrics import F1, FBeta, Recall, Accuracy, Precision, AUC
from sklearn.metrics import (
    f1_score,
    fbeta_score,
    recall_score,
    accuracy_score,
    precision_score,
    roc_auc_score,  # sklearn has no `auc_score`; roc_auc_score is the scorer used here
)
from pytorch_widedeep.metrics import MultipleMetrics
@@ -35,9 +36,12 @@ y_pred_bin_pt = torch.from_numpy(y_pred_bin_np)
("Recall", recall_score, Recall(num_classes=2, average="none")),
("F1", f1_score, F1(num_classes=2, average="none")),
("FBeta", f2_score_bin, FBeta(beta=2, num_classes=2, average="none")),
("AUC", auc_score, AUC()),
],
)
def test_binary_metrics(metric_name, sklearn_metric, torch_metric):
    if metric_name == "AUC":
        # torchmetrics' AUC takes no num_classes argument, so it is patched on
        torch_metric.num_classes = 2
    sk_res = sklearn_metric(y_true_bin_np, y_pred_bin_np.round())
    wd_metric = MultipleMetrics(metrics=[torch_metric])
    wd_logs = wd_metric(y_pred_bin_pt, y_true_bin_pt)
@@ -82,11 +86,14 @@ def f2_score_multi(y_true, y_pred, average):
("Recall", recall_score, Recall(num_classes=3, average="macro")),
("F1", f1_score, F1(num_classes=3, average="macro")),
("FBeta", f2_score_multi, FBeta(beta=3, num_classes=3, average="macro")),
("AUC", auc_score, AUC()),
],
)
def test_multiclass_metrics(metric_name, sklearn_metric, torch_metric):
    if metric_name == "Accuracy":
        sk_res = sklearn_metric(y_true_multi_np, y_pred_muli_np.argmax(axis=1))
    elif metric_name == "AUC":
        torch_metric.num_classes = 3
        # without this, sk_res would be unbound below; roc_auc_score takes the
        # full probability matrix for multiclass input
        sk_res = sklearn_metric(y_true_multi_np, y_pred_muli_np, multi_class="ovr")
    else:
        sk_res = sklearn_metric(
            y_true_multi_np, y_pred_muli_np.argmax(axis=1), average="macro"
......