diff --git a/pytorch_widedeep/metrics.py b/pytorch_widedeep/metrics.py index eea8a51bf55f0a0d9905c466ac8812cd458cd533..2cc7df34f89756d61b9dc3e63e919752ecce39c6 100644 --- a/pytorch_widedeep/metrics.py +++ b/pytorch_widedeep/metrics.py @@ -62,7 +62,7 @@ class CategoricalAccuracy(Metric): def __call__(self, y_pred:Tensor, y_true:Tensor) -> np.ndarray: top_k = (y_pred.topk(self.top_k,1)[1]) - true_k = y_true.view(len(y_true),1).expand_as(top_k) # type: ignore (ignore len vs .size()) + true_k = y_true.view(len(y_true),1).expand_as(top_k) # type: ignore self.correct_count += top_k.eq(true_k).float().sum().item() self.total_count += len(y_pred) # type: ignore accuracy = float(self.correct_count) / float(self.total_count)