diff --git a/python/paddle/fluid/metrics.py b/python/paddle/fluid/metrics.py index bb9c6fdc60089fc2b43573a6421a6f9781d2d4a8..572475b483ff0341a97a91b6c5309fcf337dacbe 100644 --- a/python/paddle/fluid/metrics.py +++ b/python/paddle/fluid/metrics.py @@ -325,14 +325,14 @@ class Auc(MetricBase): """ def __init__(self, name, curve='ROC', num_thresholds=200): - super(MetricBase, self).__init__(name, curve, num_thresholds) + super(Auc, self).__init__(name=name) self._curve = curve self._num_thresholds = num_thresholds self._epsilon = 1e-6 - self.tp_list = np.ndarray((num_thresholds, )) - self.fn_list = np.ndarray((num_thresholds, )) - self.tn_list = np.ndarray((num_thresholds, )) - self.fp_list = np.ndarray((num_thresholds, )) + self.tp_list = np.zeros((num_thresholds, )) + self.fn_list = np.zeros((num_thresholds, )) + self.tn_list = np.zeros((num_thresholds, )) + self.fp_list = np.zeros((num_thresholds, )) def update(self, labels, predictions, axis=1): if not _is_numpy_(labels): @@ -350,12 +350,12 @@ class Auc(MetricBase): tp, fn, tn, fp = 0, 0, 0, 0 for i, lbl in enumerate(labels): if lbl: - if predictions[i, 0] >= thresh: + if predictions[i, 1] >= thresh: tp += 1 else: fn += 1 else: - if predictions[i, 0] >= thresh: + if predictions[i, 1] >= thresh: fp += 1 else: tn += 1