未验证 提交 2b58c62a 编写于 作者: Q Qiao Longfei 提交者: GitHub

Update auc op (#12199)

fix AUC op
optimize its test
上级 37713f22
...@@ -24,15 +24,16 @@ class AucOp : public framework::OperatorWithKernel { ...@@ -24,15 +24,16 @@ class AucOp : public framework::OperatorWithKernel {
protected: protected:
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Out"), "Input of Out should not be null."); PADDLE_ENFORCE(ctx->HasInput("Predict"),
PADDLE_ENFORCE(ctx->HasInput("Indices"), "Input of Out should not be null.");
"Input of Indices should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Label"), PADDLE_ENFORCE(ctx->HasInput("Label"),
"Input of Label should not be null."); "Input of Label should not be null.");
auto inference_height = ctx->GetInputDim("Out")[0]; auto predict_width = ctx->GetInputDim("Predict")[1];
PADDLE_ENFORCE_EQ(predict_width, 2, "Only support binary classification");
auto predict_height = ctx->GetInputDim("Predict")[0];
auto label_height = ctx->GetInputDim("Label")[0]; auto label_height = ctx->GetInputDim("Label")[0];
PADDLE_ENFORCE_EQ(inference_height, label_height, PADDLE_ENFORCE_EQ(predict_height, label_height,
"Out and Label should have same height."); "Out and Label should have same height.");
int num_thres = ctx->Attrs().Get<int>("num_thresholds"); int num_thres = ctx->Attrs().Get<int>("num_thresholds");
...@@ -43,14 +44,14 @@ class AucOp : public framework::OperatorWithKernel { ...@@ -43,14 +44,14 @@ class AucOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("FPOut", {num_thres}); ctx->SetOutputDim("FPOut", {num_thres});
ctx->SetOutputDim("FNOut", {num_thres}); ctx->SetOutputDim("FNOut", {num_thres});
ctx->ShareLoD("Out", /*->*/ "AUC"); ctx->ShareLoD("Predict", /*->*/ "AUC");
} }
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Out")->type()), framework::ToDataType(ctx.Input<Tensor>("Predict")->type()),
ctx.device_context()); ctx.device_context());
} }
}; };
...@@ -58,18 +59,13 @@ class AucOp : public framework::OperatorWithKernel { ...@@ -58,18 +59,13 @@ class AucOp : public framework::OperatorWithKernel {
class AucOpMaker : public framework::OpProtoAndCheckerMaker { class AucOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
void Make() override { void Make() override {
AddInput("Out", AddInput("Predict",
"A floating point 2D tensor, values are in the range [0, 1]." "A floating point 2D tensor with shape [batch_size, 2], values "
"Each row is sorted in descending order. This input should be the" "are in the range [0, 1]."
"output of topk."
"Typically, this tensor indicates the probability of each label"); "Typically, this tensor indicates the probability of each label");
AddInput("Indices",
"An int 2D tensor, indicating the indices of original"
"tensor before sorting. Typically, this tensor indicates which "
"label the probability stands for.");
AddInput("Label", AddInput("Label",
"A 2D int tensor indicating the label of the training data." "A 2D int tensor indicating the label of the training data. "
"The height is batch size and width is always 1."); "shape: [batch_size, 1]");
AddInput("TP", "True-Positive value."); AddInput("TP", "True-Positive value.");
AddInput("FP", "False-Positive value."); AddInput("FP", "False-Positive value.");
AddInput("TN", "True-Negative value."); AddInput("TN", "True-Negative value.");
......
...@@ -31,7 +31,7 @@ template <typename DeviceContext, typename T> ...@@ -31,7 +31,7 @@ template <typename DeviceContext, typename T>
class AucKernel : public framework::OpKernel<T> { class AucKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto* inference = ctx.Input<Tensor>("Out"); auto* predict = ctx.Input<Tensor>("Predict");
auto* label = ctx.Input<Tensor>("Label"); auto* label = ctx.Input<Tensor>("Label");
auto* auc = ctx.Output<Tensor>("AUC"); auto* auc = ctx.Output<Tensor>("AUC");
// Only use output var for now, make sure it's persistable and // Only use output var for now, make sure it's persistable and
...@@ -41,24 +41,24 @@ class AucKernel : public framework::OpKernel<T> { ...@@ -41,24 +41,24 @@ class AucKernel : public framework::OpKernel<T> {
auto* true_negative = ctx.Output<Tensor>("TNOut"); auto* true_negative = ctx.Output<Tensor>("TNOut");
auto* false_negative = ctx.Output<Tensor>("FNOut"); auto* false_negative = ctx.Output<Tensor>("FNOut");
float* auc_data = auc->mutable_data<float>(ctx.GetPlace()); auto* auc_data = auc->mutable_data<double>(ctx.GetPlace());
std::string curve = ctx.Attr<std::string>("curve"); std::string curve = ctx.Attr<std::string>("curve");
int num_thresholds = ctx.Attr<int>("num_thresholds"); int num_thresholds = ctx.Attr<int>("num_thresholds");
std::vector<float> thresholds_list; std::vector<double> thresholds_list;
thresholds_list.reserve(num_thresholds); thresholds_list.reserve(num_thresholds);
for (int i = 1; i < num_thresholds - 1; i++) { for (int i = 1; i < num_thresholds - 1; i++) {
thresholds_list[i] = static_cast<float>(i) / (num_thresholds - 1); thresholds_list[i] = static_cast<double>(i) / (num_thresholds - 1);
} }
const float kEpsilon = 1e-7; const double kEpsilon = 1e-7;
thresholds_list[0] = 0.0f - kEpsilon; thresholds_list[0] = 0.0f - kEpsilon;
thresholds_list[num_thresholds - 1] = 1.0f + kEpsilon; thresholds_list[num_thresholds - 1] = 1.0f + kEpsilon;
size_t batch_size = inference->dims()[0]; size_t batch_size = predict->dims()[0];
size_t inference_width = inference->dims()[1]; size_t inference_width = predict->dims()[1];
const T* inference_data = inference->data<T>(); const T* inference_data = predict->data<T>();
const int64_t* label_data = label->data<int64_t>(); const auto* label_data = label->data<int64_t>();
auto* tp_data = true_positive->mutable_data<int64_t>(ctx.GetPlace()); auto* tp_data = true_positive->mutable_data<int64_t>(ctx.GetPlace());
auto* fn_data = false_negative->mutable_data<int64_t>(ctx.GetPlace()); auto* fn_data = false_negative->mutable_data<int64_t>(ctx.GetPlace());
...@@ -66,20 +66,19 @@ class AucKernel : public framework::OpKernel<T> { ...@@ -66,20 +66,19 @@ class AucKernel : public framework::OpKernel<T> {
auto* fp_data = false_positive->mutable_data<int64_t>(ctx.GetPlace()); auto* fp_data = false_positive->mutable_data<int64_t>(ctx.GetPlace());
for (int idx_thresh = 0; idx_thresh < num_thresholds; idx_thresh++) { for (int idx_thresh = 0; idx_thresh < num_thresholds; idx_thresh++) {
// caculate TP, FN, TN, FP for current thresh // calculate TP, FN, TN, FP for current thresh
int64_t tp = 0, fn = 0, tn = 0, fp = 0; int64_t tp = 0, fn = 0, tn = 0, fp = 0;
for (size_t i = 0; i < batch_size; i++) { for (size_t i = 0; i < batch_size; i++) {
// NOTE: label_data used as bool, labels >0 will be treated as true. // NOTE: label_data used as bool, labels > 0 will be treated as true.
if (label_data[i]) { if (label_data[i]) {
// use first(max) data in each row if (inference_data[i * inference_width + 1] >=
if (inference_data[i * inference_width] >=
(thresholds_list[idx_thresh])) { (thresholds_list[idx_thresh])) {
tp++; tp++;
} else { } else {
fn++; fn++;
} }
} else { } else {
if (inference_data[i * inference_width] >= if (inference_data[i * inference_width + 1] >=
(thresholds_list[idx_thresh])) { (thresholds_list[idx_thresh])) {
fp++; fp++;
} else { } else {
...@@ -94,21 +93,21 @@ class AucKernel : public framework::OpKernel<T> { ...@@ -94,21 +93,21 @@ class AucKernel : public framework::OpKernel<T> {
fp_data[idx_thresh] += fp; fp_data[idx_thresh] += fp;
} }
// epsilon to avoid divide by zero. // epsilon to avoid divide by zero.
float epsilon = 1e-6; double epsilon = 1e-6;
// Riemann sum to calculate auc. // Riemann sum to calculate auc.
Tensor tp_rate, fp_rate, rec_rate; Tensor tp_rate, fp_rate, rec_rate;
tp_rate.Resize({num_thresholds}); tp_rate.Resize({num_thresholds});
fp_rate.Resize({num_thresholds}); fp_rate.Resize({num_thresholds});
rec_rate.Resize({num_thresholds}); rec_rate.Resize({num_thresholds});
float* tp_rate_data = tp_rate.mutable_data<float>(ctx.GetPlace()); auto* tp_rate_data = tp_rate.mutable_data<double>(ctx.GetPlace());
float* fp_rate_data = fp_rate.mutable_data<float>(ctx.GetPlace()); auto* fp_rate_data = fp_rate.mutable_data<double>(ctx.GetPlace());
float* rec_rate_data = rec_rate.mutable_data<float>(ctx.GetPlace()); auto* rec_rate_data = rec_rate.mutable_data<double>(ctx.GetPlace());
for (int i = 0; i < num_thresholds; i++) { for (int i = 0; i < num_thresholds; i++) {
tp_rate_data[i] = (static_cast<float>(tp_data[i]) + epsilon) / tp_rate_data[i] = (static_cast<double>(tp_data[i]) + epsilon) /
(tp_data[i] + fn_data[i] + epsilon); (tp_data[i] + fn_data[i] + epsilon);
fp_rate_data[i] = fp_rate_data[i] =
static_cast<float>(fp_data[i]) / (fp_data[i] + tn_data[i] + epsilon); static_cast<double>(fp_data[i]) / (fp_data[i] + tn_data[i] + epsilon);
rec_rate_data[i] = (static_cast<float>(tp_data[i]) + epsilon) / rec_rate_data[i] = (static_cast<double>(tp_data[i]) + epsilon) /
(tp_data[i] + fp_data[i] + epsilon); (tp_data[i] + fp_data[i] + epsilon);
} }
*auc_data = 0.0f; *auc_data = 0.0f;
......
...@@ -114,23 +114,13 @@ def auc(input, label, curve='ROC', num_thresholds=200, topk=1): ...@@ -114,23 +114,13 @@ def auc(input, label, curve='ROC', num_thresholds=200, topk=1):
prediction = network(image, is_infer=True) prediction = network(image, is_infer=True)
auc_out=fluid.layers.auc(input=prediction, label=label) auc_out=fluid.layers.auc(input=prediction, label=label)
""" """
warnings.warn(
"This interface is not recommended, fluid.layers.auc compute the auc at every minibatch, \
but can not aggregate them and get the pass AUC, because pass \
auc can not be averaged with weighted from the minibatch auc value. \
Please use fluid.metrics.Auc, it can compute the auc value via Python natively, \
which can get every minibatch and every pass auc value.", Warning)
helper = LayerHelper("auc", **locals()) helper = LayerHelper("auc", **locals())
topk_out = helper.create_tmp_variable(dtype=input.dtype) auc_out = helper.create_tmp_variable(dtype="float64")
topk_indices = helper.create_tmp_variable(dtype="int64")
topk_out, topk_indices = nn.topk(input, k=k)
auc_out = helper.create_tmp_variable(dtype="float32")
# make tp, tn, fp, fn persistable, so that can accumulate all batches. # make tp, tn, fp, fn persistable, so that can accumulate all batches.
tp = helper.create_global_variable(persistable=True) tp = helper.create_global_variable(persistable=True, dtype='int64')
tn = helper.create_global_variable(persistable=True) tn = helper.create_global_variable(persistable=True, dtype='int64')
fp = helper.create_global_variable(persistable=True) fp = helper.create_global_variable(persistable=True, dtype='int64')
fn = helper.create_global_variable(persistable=True) fn = helper.create_global_variable(persistable=True, dtype='int64')
for var in [tp, tn, fp, fn]: for var in [tp, tn, fp, fn]:
helper.set_variable_initializer( helper.set_variable_initializer(
var, Constant( var, Constant(
...@@ -139,8 +129,7 @@ def auc(input, label, curve='ROC', num_thresholds=200, topk=1): ...@@ -139,8 +129,7 @@ def auc(input, label, curve='ROC', num_thresholds=200, topk=1):
helper.append_op( helper.append_op(
type="auc", type="auc",
inputs={ inputs={
"Out": [topk_out], "Predict": [input],
"Indices": [topk_indices],
"Label": [label], "Label": [label],
"TP": [tp], "TP": [tp],
"TN": [tn], "TN": [tn],
...@@ -156,4 +145,4 @@ def auc(input, label, curve='ROC', num_thresholds=200, topk=1): ...@@ -156,4 +145,4 @@ def auc(input, label, curve='ROC', num_thresholds=200, topk=1):
"FPOut": [fp], "FPOut": [fp],
"FNOut": [fn] "FNOut": [fn]
}) })
return auc_out return auc_out, [tp, tn, fp, fn]
...@@ -591,7 +591,7 @@ class Auc(MetricBase): ...@@ -591,7 +591,7 @@ class Auc(MetricBase):
for i in range(self._num_thresholds - 2)] for i in range(self._num_thresholds - 2)]
thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon] thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]
# caculate TP, FN, TN, FP count # calculate TP, FN, TN, FP count
for idx_thresh, thresh in enumerate(thresholds): for idx_thresh, thresh in enumerate(thresholds):
tp, fn, tn, fp = 0, 0, 0, 0 tp, fn, tn, fp = 0, 0, 0, 0
for i, lbl in enumerate(labels): for i, lbl in enumerate(labels):
......
...@@ -15,13 +15,13 @@ ...@@ -15,13 +15,13 @@
import unittest import unittest
import numpy as np import numpy as np
from op_test import OpTest from op_test import OpTest
from paddle.fluid import metrics
class TestAucOp(OpTest): class TestAucOp(OpTest):
def setUp(self): def setUp(self):
self.op_type = "auc" self.op_type = "auc"
pred = np.random.random((128, 2)).astype("float32") pred = np.random.random((128, 2)).astype("float32")
indices = np.random.randint(0, 2, (128, 2))
labels = np.random.randint(0, 2, (128, 1)) labels = np.random.randint(0, 2, (128, 1))
num_thresholds = 200 num_thresholds = 200
tp = np.zeros((num_thresholds, )).astype("int64") tp = np.zeros((num_thresholds, )).astype("int64")
...@@ -30,8 +30,7 @@ class TestAucOp(OpTest): ...@@ -30,8 +30,7 @@ class TestAucOp(OpTest):
fn = np.zeros((num_thresholds, )).astype("int64") fn = np.zeros((num_thresholds, )).astype("int64")
self.inputs = { self.inputs = {
'Out': pred, 'Predict': pred,
'Indices': indices,
'Label': labels, 'Label': labels,
'TP': tp, 'TP': tp,
'TN': tn, 'TN': tn,
...@@ -39,57 +38,18 @@ class TestAucOp(OpTest): ...@@ -39,57 +38,18 @@ class TestAucOp(OpTest):
'FN': fn 'FN': fn
} }
self.attrs = {'curve': 'ROC', 'num_thresholds': num_thresholds} self.attrs = {'curve': 'ROC', 'num_thresholds': num_thresholds}
# NOTE: sklearn use a different way to generate thresholds
# which will cause the result differs slightly:
# from sklearn.metrics import roc_curve, auc
# fpr, tpr, thresholds = roc_curve(labels, pred)
# auc_value = auc(fpr, tpr)
# we caculate AUC again using numpy for testing
kepsilon = 1e-7 # to account for floating point imprecisions
thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
for i in range(num_thresholds - 2)]
thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]
# caculate TP, FN, TN, FP count python_auc = metrics.Auc(name="auc",
tp_list = np.ndarray((num_thresholds, )) curve='ROC',
fn_list = np.ndarray((num_thresholds, )) num_thresholds=num_thresholds)
tn_list = np.ndarray((num_thresholds, )) python_auc.update(pred, labels)
fp_list = np.ndarray((num_thresholds, ))
for idx_thresh, thresh in enumerate(thresholds):
tp, fn, tn, fp = 0, 0, 0, 0
for i, lbl in enumerate(labels):
if lbl:
if pred[i, 0] >= thresh:
tp += 1
else:
fn += 1
else:
if pred[i, 0] >= thresh:
fp += 1
else:
tn += 1
tp_list[idx_thresh] = tp
fn_list[idx_thresh] = fn
tn_list[idx_thresh] = tn
fp_list[idx_thresh] = fp
epsilon = 1e-6
tpr = (tp_list.astype("float32") + epsilon) / (
tp_list + fn_list + epsilon)
fpr = fp_list.astype("float32") / (fp_list + tn_list + epsilon)
rec = (tp_list.astype("float32") + epsilon) / (
tp_list + fp_list + epsilon)
x = fpr[:num_thresholds - 1] - fpr[1:]
y = (tpr[:num_thresholds - 1] + tpr[1:]) / 2.0
auc_value = np.sum(x * y)
self.outputs = { self.outputs = {
'AUC': auc_value, 'AUC': python_auc.eval(),
'TPOut': tp_list, 'TPOut': python_auc.tp_list,
'FNOut': fn_list, 'FNOut': python_auc.fn_list,
'TNOut': tn_list, 'TNOut': python_auc.tn_list,
'FPOut': fp_list 'FPOut': python_auc.fp_list
} }
def test_check_output(self): def test_check_output(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册