提交 453220f6 编写于 作者: J jrzaurin

added precision. Need to test the multiclass case. For binary seems to work fine. More test needed

上级 31c2d8ef
......@@ -4,7 +4,7 @@ import pandas as pd
from pytorch_widedeep.optim import RAdam
from pytorch_widedeep.models import Wide, WideDeep, DeepDense
from pytorch_widedeep.metrics import BinaryAccuracy
from pytorch_widedeep.metrics import BinaryAccuracy, Accuracy, Precision
from pytorch_widedeep.callbacks import (
LRHistory,
EarlyStopping,
......@@ -76,7 +76,7 @@ if __name__ == "__main__":
EarlyStopping,
ModelCheckpoint(filepath="model_weights/wd_out"),
]
metrics = [BinaryAccuracy]
metrics = [Precision]
model.compile(
method="binary",
......
......@@ -3,7 +3,7 @@ import torch
import pandas as pd
from pytorch_widedeep.models import Wide, WideDeep, DeepDense
from pytorch_widedeep.metrics import CategoricalAccuracy
from pytorch_widedeep.metrics import CategoricalAccuracy, Accuracy
from pytorch_widedeep.preprocessing import WidePreprocessor, DensePreprocessor
use_cuda = torch.cuda.is_available()
......@@ -48,7 +48,7 @@ if __name__ == "__main__":
continuous_cols=continuous_cols,
)
model = WideDeep(wide=wide, deepdense=deepdense, pred_dim=3)
model.compile(method="multiclass", metrics=[CategoricalAccuracy])
model.compile(method="multiclass", metrics=[Accuracy])
model.fit(
X_wide=X_wide,
......
......@@ -46,6 +46,78 @@ class MetricCallback(Callback):
self.container.reset()
class Precision(Metric):
def __init__(self):
self.true_positives = 0
self.all_positives = 0
self.eps = 1e-20
self._name = "prec"
def reset(self) -> None:
self.true_positives = 0
self.all_positives = 0
def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray:
num_class = y_pred.size(1)
if num_class == 1:
y_pred = y_pred.round()
y_true = y_true.view(-1, 1)
elif num_class > 1:
y_true = torch.eye(num_class)[y_true.long()]
y_pred = y_pred.topk(1, 1)[1].view(-1)
y_pred = torch.eye(num_class)[y_pred.long()]
self.true_positives += (y_true * y_pred).sum().item()
self.all_positives += y_pred.sum(dim=0)
precision = (self.true_positives / (self.all_positives + self.eps)).mean().item()
return precision
class Accuracy(Metric):
r"""Class to calculate the accuracy for both binary and categorical problems
Parameters
----------
top_k: int
Accuracy will be computed using the top k most likely classes in
multiclass problems
"""
def __init__(self, top_k=1):
self.top_k = top_k
self.correct_count = 0
self.total_count = 0
self._name = "acc"
def reset(self):
"""
resets counters to 0
"""
self.correct_count = 0
self.total_count = 0
def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray:
num_classes = y_pred.size(1)
if num_classes == 1:
y_pred_round = y_pred.round()
self.correct_count += y_pred_round.eq(y_true.view(-1, 1)).sum().item()
elif num_classes > 1:
top_k = y_pred.topk(self.top_k, 1)[1]
true_k = y_true.view(-1, 1).expand_as(top_k) # type: ignore
self.correct_count += top_k.eq(true_k).sum().item()
self.total_count += len(y_pred) # type: ignore
accuracy = float(self.correct_count) / float(self.total_count)
return accuracy
class CategoricalAccuracy(Metric):
r"""Class to calculate the categorical accuracy for multicategorical problems
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册