提交 4c6a2fd8 编写于 作者: J jrzaurin

replaced data() by item() for newer pytorch versions

上级 a94d4b3f
......@@ -5,6 +5,8 @@ from .callbacks import Callback
from .wdtypes import *
import pdb
class Metric(object):
def reset(self):
......@@ -58,9 +60,9 @@ class CategoricalAccuracy(Metric):
self.total_count = 0
def __call__(self, y_pred:Tensor, y_true:Tensor) -> np.ndarray:
top_k = y_pred.topk(self.top_k,1)[1]
top_k = (y_pred.topk(self.top_k,1)[1])
true_k = y_true.view(len(y_true),1).expand_as(top_k)
self.correct_count += top_k.eq(true_k).float().sum().data[0]
self.correct_count += top_k.eq(true_k).float().sum().item()
self.total_count += len(y_pred)
accuracy = float(self.correct_count) / float(self.total_count)
return np.round(accuracy, 4)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册