diff --git a/propeller/paddle/train/metrics.py b/propeller/paddle/train/metrics.py index 5603ee2c0a22c45c21ec5ff29e1d65be912663f8..096b3808d0be43928452e91df0f9e23460886d66 100644 --- a/propeller/paddle/train/metrics.py +++ b/propeller/paddle/train/metrics.py @@ -173,8 +173,8 @@ class Precision(Metrics): def eval(self): """doc""" tp = (self.label_saver & self.pred_saver).astype(np.int64).sum() - t = self.label_saver.astype(np.int64).sum() - return tp / t + p = self.pred_saver.astype(np.int64).sum() + return tp / p class Recall(Precision): @@ -183,8 +183,8 @@ class Recall(Precision): def eval(self): """doc""" tp = (self.label_saver & self.pred_saver).astype(np.int64).sum() - p = (self.label_saver).astype(np.int64).sum() - return tp / p + t = (self.label_saver).astype(np.int64).sum() + return tp / t class F1(Precision): @@ -195,8 +195,8 @@ class F1(Precision): tp = (self.label_saver & self.pred_saver).astype(np.int64).sum() t = self.label_saver.astype(np.int64).sum() p = self.pred_saver.astype(np.int64).sum() - precision = tp / (t + 1.e-6) - recall = tp / (p + 1.e-6) + precision = tp / (p + 1.e-6) + recall = tp / (t + 1.e-6) return 2 * precision * recall / (precision + recall + 1.e-6)