diff --git a/examples/adult_script.py b/examples/adult_script.py index cd722a3e81f953b4d80863ed9432744e60227501..1d622fc3ba074af3a97e6cf11290b37361603768 100644 --- a/examples/adult_script.py +++ b/examples/adult_script.py @@ -4,7 +4,7 @@ import pandas as pd from pytorch_widedeep.optim import RAdam from pytorch_widedeep.models import Wide, WideDeep, DeepDense -from pytorch_widedeep.metrics import BinaryAccuracy +from pytorch_widedeep.metrics import BinaryAccuracy, Accuracy, Precision from pytorch_widedeep.callbacks import ( LRHistory, EarlyStopping, @@ -76,7 +76,7 @@ if __name__ == "__main__": EarlyStopping, ModelCheckpoint(filepath="model_weights/wd_out"), ] - metrics = [BinaryAccuracy] + metrics = [Precision] model.compile( method="binary", diff --git a/examples/airbnb_script_multiclass.py b/examples/airbnb_script_multiclass.py index 236b454cd100a00c1b1cadab5305a0ed547e5375..eb36085689d94a4ad181ad3f12ba86a6cae943d0 100644 --- a/examples/airbnb_script_multiclass.py +++ b/examples/airbnb_script_multiclass.py @@ -3,7 +3,7 @@ import torch import pandas as pd from pytorch_widedeep.models import Wide, WideDeep, DeepDense -from pytorch_widedeep.metrics import CategoricalAccuracy +from pytorch_widedeep.metrics import CategoricalAccuracy, Accuracy from pytorch_widedeep.preprocessing import WidePreprocessor, DensePreprocessor use_cuda = torch.cuda.is_available() @@ -48,7 +48,7 @@ if __name__ == "__main__": continuous_cols=continuous_cols, ) model = WideDeep(wide=wide, deepdense=deepdense, pred_dim=3) - model.compile(method="multiclass", metrics=[CategoricalAccuracy]) + model.compile(method="multiclass", metrics=[Accuracy]) model.fit( X_wide=X_wide, diff --git a/pytorch_widedeep/metrics.py b/pytorch_widedeep/metrics.py index 107272a39b2575c7dc025bac9908644e77886b39..c3f111e5d012cec0b133b0581509c4e71e5fe5b1 100644 --- a/pytorch_widedeep/metrics.py +++ b/pytorch_widedeep/metrics.py @@ -46,6 +46,78 @@ class MetricCallback(Callback): self.container.reset() +class Precision(Metric): + + def __init__(self): + self.true_positives = 0 + self.all_positives = 0 + self.eps = 1e-20 + + self._name = "prec" + + def reset(self) -> None: + self.true_positives = 0 + self.all_positives = 0 + + def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray: + num_class = y_pred.size(1) + + if num_class == 1: + y_pred = y_pred.round() + y_true = y_true.view(-1, 1) + elif num_class > 1: + y_true = torch.eye(num_class)[y_true.long()] + y_pred = y_pred.topk(1, 1)[1].view(-1) + y_pred = torch.eye(num_class)[y_pred.long()] + + self.true_positives += (y_true * y_pred).sum().item() + self.all_positives += y_pred.sum(dim=0) + + precision = (self.true_positives / (self.all_positives + self.eps)).mean().item() + + return precision + + +class Accuracy(Metric): + r"""Class to calculate the accuracy for both binary and categorical problems + + Parameters + ---------- + top_k: int + Accuracy will be computed using the top k most likely classes in + multiclass problems + """ + + def __init__(self, top_k=1): + self.top_k = top_k + self.correct_count = 0 + self.total_count = 0 + + self._name = "acc" + + def reset(self): + """ + resets counters to 0 + """ + self.correct_count = 0 + self.total_count = 0 + + def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray: + num_classes = y_pred.size(1) + + if num_classes == 1: + y_pred_round = y_pred.round() + self.correct_count += y_pred_round.eq(y_true.view(-1, 1)).sum().item() + elif num_classes > 1: + top_k = y_pred.topk(self.top_k, 1)[1] + true_k = y_true.view(-1, 1).expand_as(top_k) # type: ignore + self.correct_count += top_k.eq(true_k).sum().item() + + self.total_count += len(y_pred) # type: ignore + accuracy = float(self.correct_count) / float(self.total_count) + return accuracy + + class CategoricalAccuracy(Metric): r"""Class to calculate the categorical accuracy for multicategorical problems