未验证 提交 1d4d8de0 编写于 作者: Q Qiao Longfei 提交者: GitHub

Merge pull request #11574 from jacquesqiao/fix-auc

fix auc
...@@ -325,14 +325,14 @@ class Auc(MetricBase): ...@@ -325,14 +325,14 @@ class Auc(MetricBase):
""" """
def __init__(self, name, curve='ROC', num_thresholds=200): def __init__(self, name, curve='ROC', num_thresholds=200):
super(MetricBase, self).__init__(name, curve, num_thresholds) super(Auc, self).__init__(name=name)
self._curve = curve self._curve = curve
self._num_thresholds = num_thresholds self._num_thresholds = num_thresholds
self._epsilon = 1e-6 self._epsilon = 1e-6
self.tp_list = np.ndarray((num_thresholds, )) self.tp_list = np.zeros((num_thresholds, ))
self.fn_list = np.ndarray((num_thresholds, )) self.fn_list = np.zeros((num_thresholds, ))
self.tn_list = np.ndarray((num_thresholds, )) self.tn_list = np.zeros((num_thresholds, ))
self.fp_list = np.ndarray((num_thresholds, )) self.fp_list = np.zeros((num_thresholds, ))
def update(self, labels, predictions, axis=1): def update(self, labels, predictions, axis=1):
if not _is_numpy_(labels): if not _is_numpy_(labels):
...@@ -350,12 +350,12 @@ class Auc(MetricBase): ...@@ -350,12 +350,12 @@ class Auc(MetricBase):
tp, fn, tn, fp = 0, 0, 0, 0 tp, fn, tn, fp = 0, 0, 0, 0
for i, lbl in enumerate(labels): for i, lbl in enumerate(labels):
if lbl: if lbl:
if predictions[i, 0] >= thresh: if predictions[i, 1] >= thresh:
tp += 1 tp += 1
else: else:
fn += 1 fn += 1
else: else:
if predictions[i, 0] >= thresh: if predictions[i, 1] >= thresh:
fp += 1 fp += 1
else: else:
tn += 1 tn += 1
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册