From 399a5eec69a34d6336858179080ae3e5dc67ee90 Mon Sep 17 00:00:00 2001 From: typhoonzero Date: Wed, 13 Sep 2017 12:45:23 +0800 Subject: [PATCH] auc_op --- paddle/operators/auc_op.cc | 34 ++++++++++++++-------------- paddle/operators/auc_op.h | 45 ++++++++++++++++++++++---------------- 2 files changed, 44 insertions(+), 35 deletions(-) diff --git a/paddle/operators/auc_op.cc b/paddle/operators/auc_op.cc index 3a43f9bcc48..63f0d50fdc3 100644 --- a/paddle/operators/auc_op.cc +++ b/paddle/operators/auc_op.cc @@ -28,15 +28,12 @@ class AucOp : public framework::OperatorWithKernel { PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"), "Input of Inference must be initialized."); auto *inference = ctx.Input("Inference"); - auto *inference_prob = ctx.Input("InferenceProb"); auto *label = ctx.Input("Label"); - PADDLE_ENFORCE_EQ(label->dims().size(), 1, "label must be a vector"); - PADDLE_ENFORCE_EQ(inference->dims()[0], label->dims()[0], - "inference size must be the same as label size"); - PADDLE_ENFORCE_EQ(inference->dims(), inference_prob->dims()); + PADDLE_ENFORCE_EQ(inference->dims(), label->dims(), + "inference should have same shape as label"); - ctx.Output("Accuracy")->Resize({1}); + ctx.Output("AUC")->Resize({1}); } }; @@ -45,14 +42,15 @@ class AucOpMaker : public framework::OpProtoAndCheckerMaker { AucOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("Inference", - "Topk(indices) the network output, float value indicating " - "probabilities of classification"); - AddInput("InferenceProb", - "Topk(values) the network output, float value indicating " - "probabilities of classification"); - AddInput("Label", "Label of the training data"); - // TODO(typhoonzero): support weight - AddOutput("AUC", "Area Under Curve caculations"); + "A floating point `Tensor` of arbitrary shape and whose values" + "are in the range `[0, 1]`."); + AddInput("Label", + "A `Tensor` whose shape matches " + "`Inference`. Will be cast to `bool`."); + // TODO(typhoonzero): support weight input + AddOutput("AUC", + "A scalar `Tensor` representing the " + "current area-under-curve."); AddAttr("curve", "Possible curves are ROC and PR") .SetDefault("ROC"); AddAttr("num_thresholds", @@ -62,12 +60,16 @@ class AucOpMaker : public framework::OpProtoAndCheckerMaker { AddComment( R"DOC(Computes the AUC according forward output and label. + Best to use for binary classification evaluations. + If `label` can be values other than 0 and 1, it will be cast + to bool. + You can find the definations here: https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve Possible curves are: - ROC: Receiver operating characteristic - PR: Precision Recall + - ROC: Receiver operating characteristic + - PR: Precision Recall )DOC"); } }; diff --git a/paddle/operators/auc_op.h b/paddle/operators/auc_op.h index fd110c06e64..b6ca74f1af2 100644 --- a/paddle/operators/auc_op.h +++ b/paddle/operators/auc_op.h @@ -22,12 +22,15 @@ namespace operators { using Tensor = framework::Tensor; +template +using EigenVector = framework::EigenVector; + template class AucKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { auto* inference = ctx.Input("Inference"); - auto* inference_prob = ctx.Input("InferenceProb"); auto* label = ctx.Input("Label"); auto* auc = ctx.Output("AUC"); @@ -44,14 +47,20 @@ class AucKernel : public framework::OpKernel { thresholds_list[0] = 0.0f - kEpsilon; thresholds_list[num_thresholds - 1] = 1.0f + kEpsilon; - const int* inference_data = inference->data(); - const T* inference_prob_data = inference_prob->data(); - const T* label_data = label->data(); + size_t num_samples = inference->numel(); + + const T* inference_data = inference->data(); + Tensor label_casted; + label_casted.Resize(label->dims()); + bool* label_casted_data = label_casted.mutable_data(ctx.GetPlace()); - size_t num_samples = inference->dims()[0]; - size_t class_dim = inference->dims()[1]; + const int* label_data = label->data(); + // cast label_data to bool + for (size_t i = 0; i < num_samples; i++) { + label_casted_data[i] = static_cast(label_data[i]); + } - // create local tensor for storing the curve: TP, FN, TN, FP + // Create local tensor for storing the curve: TP, FN, TN, FP // TODO(typhoonzero): put these tensors in Scope // TODO(typhoonzero): use op to caculate these values. Tensor true_positive, false_positive, true_negative, false_negative; @@ -72,19 +81,17 @@ class AucKernel : public framework::OpKernel { // caculate TP, FN, TN, FP for current thresh int tp, fn, tn, fp = 0; for (size_t i = 0; i < num_samples; i++) { - for (size_t j = 0; j < class_dim; j++) { - if (inference_data[i * class_dim + j] == label_data[i]) { - if (inference_prob_data[i * class_dim + j] >= (*thresh)) { - tp++; - } else { - tn++; - } + if (label_casted_data[i]) { + if (inference_data[i] >= (*thresh)) { + tp++; + } else { + tn++; + } + } else { + if (inference_data[i] >= (*thresh)) { + fp++; } else { - if (inference_prob_data[i * class_dim + j] >= (*thresh)) { - fp++; - } else { - fn++; - } + fn++; } } } -- GitLab