提交 4aa5da05 编写于 作者: Q qiaolongfei

fix auc

上级 4ec9ecae
......@@ -325,14 +325,14 @@ class Auc(MetricBase):
"""
def __init__(self, name, curve='ROC', num_thresholds=200):
super(MetricBase, self).__init__(name, curve, num_thresholds)
super(Auc, self).__init__(name=name)
self._curve = curve
self._num_thresholds = num_thresholds
self._epsilon = 1e-6
self.tp_list = np.ndarray((num_thresholds, ))
self.fn_list = np.ndarray((num_thresholds, ))
self.tn_list = np.ndarray((num_thresholds, ))
self.fp_list = np.ndarray((num_thresholds, ))
self.tp_list = np.zeros((num_thresholds, ))
self.fn_list = np.zeros((num_thresholds, ))
self.tn_list = np.zeros((num_thresholds, ))
self.fp_list = np.zeros((num_thresholds, ))
def update(self, labels, predictions, axis=1):
if not _is_numpy_(labels):
......@@ -350,12 +350,12 @@ class Auc(MetricBase):
tp, fn, tn, fp = 0, 0, 0, 0
for i, lbl in enumerate(labels):
if lbl:
if predictions[i, 0] >= thresh:
if predictions[i, 1] >= thresh:
tp += 1
else:
fn += 1
else:
if predictions[i, 0] >= thresh:
if predictions[i, 1] >= thresh:
fp += 1
else:
tn += 1
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册