From 16e8a5f144b98e0d66b57dee90e668500d504bf2 Mon Sep 17 00:00:00 2001 From: Jason N Date: Thu, 9 Apr 2020 11:16:02 +0800 Subject: [PATCH] Update metrics.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 修改了Precision, Recall, F1中的bug --- propeller/paddle/train/metrics.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/propeller/paddle/train/metrics.py b/propeller/paddle/train/metrics.py index 5603ee2..096b380 100644 --- a/propeller/paddle/train/metrics.py +++ b/propeller/paddle/train/metrics.py @@ -173,8 +173,8 @@ class Precision(Metrics): def eval(self): """doc""" tp = (self.label_saver & self.pred_saver).astype(np.int64).sum() - t = self.label_saver.astype(np.int64).sum() - return tp / t + p = self.pred_saver.astype(np.int64).sum() + return tp / p class Recall(Precision): @@ -183,8 +183,8 @@ class Recall(Precision): def eval(self): """doc""" tp = (self.label_saver & self.pred_saver).astype(np.int64).sum() - p = (self.label_saver).astype(np.int64).sum() - return tp / p + t = (self.label_saver).astype(np.int64).sum() + return tp / t class F1(Precision): @@ -195,8 +195,8 @@ class F1(Precision): tp = (self.label_saver & self.pred_saver).astype(np.int64).sum() t = self.label_saver.astype(np.int64).sum() p = self.pred_saver.astype(np.int64).sum() - precision = tp / (t + 1.e-6) - recall = tp / (p + 1.e-6) + precision = tp / (p + 1.e-6) + recall = tp / (t + 1.e-6) return 2 * precision * recall / (precision + recall + 1.e-6) -- GitLab