提交 c7eef34c 编写于 作者: T typhoonzero

auc cpu only

上级 399a5eec
......@@ -31,9 +31,9 @@ class AucOp : public framework::OperatorWithKernel {
auto *label = ctx.Input<framework::Tensor>("Label");
PADDLE_ENFORCE_EQ(inference->dims(), label->dims(),
"inference should have same shape as label");
"inference and label should have same shape");
ctx.Output<Tensor>("AUC")->Resize({1});
ctx.Output<framework::Tensor>("AUC")->Resize({1});
}
};
......@@ -51,6 +51,7 @@ class AucOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput("AUC",
"A scalar `Tensor` representing the "
"current area-under-curve.");
AddAttr<std::string>("curve", "Possible curves are ROC and PR")
.SetDefault("ROC");
AddAttr<int>("num_thresholds",
......
......@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <iostream>
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
......@@ -75,23 +75,21 @@ class AucKernel : public framework::OpKernel {
int* tn_data = true_negative.mutable_data<int>(ctx.GetPlace());
int* fp_data = false_positive.mutable_data<int>(ctx.GetPlace());
for (auto thresh = thresholds_list.begin(); thresh != thresholds_list.end();
thresh++) {
size_t idx_thresh = thresh - thresholds_list.begin();
for (int idx_thresh = 0; idx_thresh < num_thresholds; idx_thresh++) {
// calculate TP, FN, TN, FP for current thresh
int tp, fn, tn, fp = 0;
int tp = 0, fn = 0, tn = 0, fp = 0;
for (size_t i = 0; i < num_samples; i++) {
if (label_casted_data[i]) {
if (inference_data[i] >= (*thresh)) {
if (inference_data[i] >= (thresholds_list[idx_thresh])) {
tp++;
} else {
tn++;
fn++;
}
} else {
if (inference_data[i] >= (*thresh)) {
if (inference_data[i] >= (thresholds_list[idx_thresh])) {
fp++;
} else {
fn++;
tn++;
}
}
}
......@@ -118,11 +116,11 @@ class AucKernel : public framework::OpKernel {
rec_rate_data[i] =
((float)tp_data[i] + epsilon) / (tp_data[i] + fp_data[i] + epsilon);
}
*auc_data = 0.0f;
if (curve == "ROC") {
for (int i = 1; i < num_thresholds; i++) {
auto dx = fp_rate_data[i] - fp_rate_data[i - 1];
auto y = (tp_rate_data[i] + tp_rate_data[i - 1]) / 2.0f;
for (int i = 0; i < num_thresholds - 1; i++) {
auto dx = fp_rate_data[i] - fp_rate_data[i + 1];
auto y = (tp_rate_data[i] + tp_rate_data[i + 1]) / 2.0f;
*auc_data = *auc_data + dx * y;
}
} else if (curve == "PR") {
......
import unittest
import numpy as np
from op_test import OpTest
class TestAucOp(OpTest):
    """Checks the `auc` operator against a NumPy reference for the ROC curve."""

    def setUp(self):
        """Build random inputs and the expected ROC AUC.

        Sets `op_type`, `inputs` (128 random float32 scores and 128 random
        0/1 labels), `attrs` (ROC curve, 200 thresholds) and `outputs`
        (the reference AUC computed below with NumPy).
        """
        self.op_type = "auc"
        pred = np.random.random((128)).astype("float32")
        labels = np.random.randint(0, 2, (128, ))
        num_thresholds = 200
        self.inputs = {'Inference': pred, 'Label': labels}
        self.attrs = {'curve': 'ROC', 'num_thresholds': num_thresholds}

        # NOTE: sklearn uses a different way to generate thresholds, which
        # makes its result differ slightly:
        #     from sklearn.metrics import roc_curve, auc
        #     fpr, tpr, thresholds = roc_curve(labels, pred)
        #     auc_value = auc(fpr, tpr)
        # so the AUC is calculated again here using NumPy for testing.
        kepsilon = 1e-7  # to account for floating point imprecisions
        thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
                      for i in range(num_thresholds - 2)]
        # Pad both ends so the first threshold lies just below 0 and the
        # last one just above 1.
        thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]

        # Calculate the TP, FN, TN and FP counts for every threshold.
        tp_list = np.zeros((num_thresholds, ))
        fn_list = np.zeros((num_thresholds, ))
        tn_list = np.zeros((num_thresholds, ))
        fp_list = np.zeros((num_thresholds, ))
        for idx_thresh, thresh in enumerate(thresholds):
            tp, fn, tn, fp = 0, 0, 0, 0
            for score, lbl in zip(pred, labels):
                # A sample is predicted positive when its score reaches
                # the current threshold.
                predicted_positive = score >= thresh
                if lbl:
                    if predicted_positive:
                        tp += 1
                    else:
                        fn += 1
                else:
                    if predicted_positive:
                        fp += 1
                    else:
                        tn += 1
            tp_list[idx_thresh] = tp
            fn_list[idx_thresh] = fn
            tn_list[idx_thresh] = tn
            fp_list[idx_thresh] = fp

        # epsilon keeps the denominators non-zero when a class has no samples.
        epsilon = 1e-6
        tpr = (tp_list.astype("float32") + epsilon) / (
            tp_list + fn_list + epsilon)
        fpr = fp_list.astype("float32") / (fp_list + tn_list + epsilon)

        # Trapezoidal sum over consecutive thresholds: thresholds ascend, so
        # the width is taken as fpr[i] - fpr[i + 1].
        x = fpr[:-1] - fpr[1:]
        y = (tpr[:-1] + tpr[1:]) / 2.0
        auc_value = np.sum(x * y)

        self.outputs = {'AUC': auc_value}

    def test_check_output(self):
        """Run the operator and compare its output with `self.outputs`."""
        self.check_output()
# Run the test cases in this module when the file is executed directly.
if __name__ == "__main__":
unittest.main()
......@@ -21,6 +21,9 @@ class TestTopkOp(OpTest):
self.outputs = {'Out': output, 'Indices': indices}
def test_check_output(self):
self.check_output()
class TestTopkOp3d(OpTest):
def setUp(self):
......@@ -42,6 +45,9 @@ class TestTopkOp3d(OpTest):
self.outputs = {'Out': output, 'Indices': indices}
def test_check_output(self):
self.check_output()
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册