提交 399a5eec 编写于 作者: T typhoonzero

auc_op

上级 f4e31347
...@@ -28,15 +28,12 @@ class AucOp : public framework::OperatorWithKernel { ...@@ -28,15 +28,12 @@ class AucOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"), PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
"Input of Inference must be initialized."); "Input of Inference must be initialized.");
auto *inference = ctx.Input<framework::Tensor>("Inference"); auto *inference = ctx.Input<framework::Tensor>("Inference");
auto *inference_prob = ctx.Input<framework::Tensor>("InferenceProb");
auto *label = ctx.Input<framework::Tensor>("Label"); auto *label = ctx.Input<framework::Tensor>("Label");
PADDLE_ENFORCE_EQ(label->dims().size(), 1, "label must be a vector"); PADDLE_ENFORCE_EQ(inference->dims(), label->dims(),
PADDLE_ENFORCE_EQ(inference->dims()[0], label->dims()[0], "inference should have same shape as label");
"inference size must be the same as label size");
PADDLE_ENFORCE_EQ(inference->dims(), inference_prob->dims());
ctx.Output<Tensor>("Accuracy")->Resize({1}); ctx.Output<Tensor>("AUC")->Resize({1});
} }
}; };
...@@ -45,14 +42,15 @@ class AucOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -45,14 +42,15 @@ class AucOpMaker : public framework::OpProtoAndCheckerMaker {
AucOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) AucOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Inference", AddInput("Inference",
"Topk(indices) the network output, float value indicating " "A floating point `Tensor` of arbitrary shape and whose values"
"probabilities of classification"); "are in the range `[0, 1]`.");
AddInput("InferenceProb", AddInput("Label",
"Topk(values) the network output, float value indicating " "A `Tensor` whose shape matches "
"probabilities of classification"); "`Inference`. Will be cast to `bool`.");
AddInput("Label", "Label of the training data"); // TODO(typhoonzero): support weight input
// TODO(typhoonzero): support weight AddOutput("AUC",
AddOutput("AUC", "Area Under Curve caculations"); "A scalar `Tensor` representing the "
"current area-under-curve.");
AddAttr<std::string>("curve", "Possible curves are ROC and PR") AddAttr<std::string>("curve", "Possible curves are ROC and PR")
.SetDefault("ROC"); .SetDefault("ROC");
AddAttr<int>("num_thresholds", AddAttr<int>("num_thresholds",
...@@ -62,12 +60,16 @@ class AucOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -62,12 +60,16 @@ class AucOpMaker : public framework::OpProtoAndCheckerMaker {
AddComment( AddComment(
R"DOC(Computes the AUC according forward output and label. R"DOC(Computes the AUC according forward output and label.
Best to use for binary classification evaluations.
If `label` can be values other than 0 and 1, it will be cast
to bool.
You can find the definations here: You can find the definations here:
https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve
Possible curves are: Possible curves are:
ROC: Receiver operating characteristic - ROC: Receiver operating characteristic
PR: Precision Recall - PR: Precision Recall
)DOC"); )DOC");
} }
}; };
......
...@@ -22,12 +22,15 @@ namespace operators { ...@@ -22,12 +22,15 @@ namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename Place, typename T> template <typename Place, typename T>
class AucKernel : public framework::OpKernel { class AucKernel : public framework::OpKernel {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto* inference = ctx.Input<Tensor>("Inference"); auto* inference = ctx.Input<Tensor>("Inference");
auto* inference_prob = ctx.Input<Tensor>("InferenceProb");
auto* label = ctx.Input<Tensor>("Label"); auto* label = ctx.Input<Tensor>("Label");
auto* auc = ctx.Output<Tensor>("AUC"); auto* auc = ctx.Output<Tensor>("AUC");
...@@ -44,14 +47,20 @@ class AucKernel : public framework::OpKernel { ...@@ -44,14 +47,20 @@ class AucKernel : public framework::OpKernel {
thresholds_list[0] = 0.0f - kEpsilon; thresholds_list[0] = 0.0f - kEpsilon;
thresholds_list[num_thresholds - 1] = 1.0f + kEpsilon; thresholds_list[num_thresholds - 1] = 1.0f + kEpsilon;
const int* inference_data = inference->data<int>(); size_t num_samples = inference->numel();
const T* inference_prob_data = inference_prob->data<T>();
const T* label_data = label->data<T>(); const T* inference_data = inference->data<T>();
Tensor label_casted;
label_casted.Resize(label->dims());
bool* label_casted_data = label_casted.mutable_data<bool>(ctx.GetPlace());
size_t num_samples = inference->dims()[0]; const int* label_data = label->data<int>();
size_t class_dim = inference->dims()[1]; // cast label_data to bool
for (size_t i = 0; i < num_samples; i++) {
label_casted_data[i] = static_cast<bool>(label_data[i]);
}
// create local tensor for storing the curve: TP, FN, TN, FP // Create local tensor for storing the curve: TP, FN, TN, FP
// TODO(typhoonzero): put these tensors in Scope // TODO(typhoonzero): put these tensors in Scope
// TODO(typhoonzero): use op to caculate these values. // TODO(typhoonzero): use op to caculate these values.
Tensor true_positive, false_positive, true_negative, false_negative; Tensor true_positive, false_positive, true_negative, false_negative;
...@@ -72,19 +81,17 @@ class AucKernel : public framework::OpKernel { ...@@ -72,19 +81,17 @@ class AucKernel : public framework::OpKernel {
// caculate TP, FN, TN, FP for current thresh // caculate TP, FN, TN, FP for current thresh
int tp, fn, tn, fp = 0; int tp, fn, tn, fp = 0;
for (size_t i = 0; i < num_samples; i++) { for (size_t i = 0; i < num_samples; i++) {
for (size_t j = 0; j < class_dim; j++) { if (label_casted_data[i]) {
if (inference_data[i * class_dim + j] == label_data[i]) { if (inference_data[i] >= (*thresh)) {
if (inference_prob_data[i * class_dim + j] >= (*thresh)) { tp++;
tp++; } else {
} else { tn++;
tn++; }
} } else {
if (inference_data[i] >= (*thresh)) {
fp++;
} else { } else {
if (inference_prob_data[i * class_dim + j] >= (*thresh)) { fn++;
fp++;
} else {
fn++;
}
} }
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册