From 4aa5da0550f067a58fc32f14c65f794fce3890f2 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Tue, 19 Jun 2018 20:19:53 +0800 Subject: [PATCH] fix auc --- python/paddle/fluid/metrics.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/python/paddle/fluid/metrics.py b/python/paddle/fluid/metrics.py index bb9c6fdc600..572475b483f 100644 --- a/python/paddle/fluid/metrics.py +++ b/python/paddle/fluid/metrics.py @@ -325,14 +325,14 @@ class Auc(MetricBase): """ def __init__(self, name, curve='ROC', num_thresholds=200): - super(MetricBase, self).__init__(name, curve, num_thresholds) + super(Auc, self).__init__(name=name) self._curve = curve self._num_thresholds = num_thresholds self._epsilon = 1e-6 - self.tp_list = np.ndarray((num_thresholds, )) - self.fn_list = np.ndarray((num_thresholds, )) - self.tn_list = np.ndarray((num_thresholds, )) - self.fp_list = np.ndarray((num_thresholds, )) + self.tp_list = np.zeros((num_thresholds, )) + self.fn_list = np.zeros((num_thresholds, )) + self.tn_list = np.zeros((num_thresholds, )) + self.fp_list = np.zeros((num_thresholds, )) def update(self, labels, predictions, axis=1): if not _is_numpy_(labels): @@ -350,12 +350,12 @@ class Auc(MetricBase): tp, fn, tn, fp = 0, 0, 0, 0 for i, lbl in enumerate(labels): if lbl: - if predictions[i, 0] >= thresh: + if predictions[i, 1] >= thresh: tp += 1 else: fn += 1 else: - if predictions[i, 0] >= thresh: + if predictions[i, 1] >= thresh: fp += 1 else: tn += 1 -- GitLab