提交 399a5eec 编写于 作者: T typhoonzero

auc_op

上级 f4e31347
......@@ -28,15 +28,12 @@ class AucOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
"Input of Inference must be initialized.");
auto *inference = ctx.Input<framework::Tensor>("Inference");
auto *inference_prob = ctx.Input<framework::Tensor>("InferenceProb");
auto *label = ctx.Input<framework::Tensor>("Label");
PADDLE_ENFORCE_EQ(label->dims().size(), 1, "label must be a vector");
PADDLE_ENFORCE_EQ(inference->dims()[0], label->dims()[0],
"inference size must be the same as label size");
PADDLE_ENFORCE_EQ(inference->dims(), inference_prob->dims());
PADDLE_ENFORCE_EQ(inference->dims(), label->dims(),
"inference should have same shape as label");
ctx.Output<Tensor>("Accuracy")->Resize({1});
ctx.Output<Tensor>("AUC")->Resize({1});
}
};
......@@ -45,14 +42,15 @@ class AucOpMaker : public framework::OpProtoAndCheckerMaker {
AucOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Inference",
"Topk(indices) the network output, float value indicating "
"probabilities of classification");
AddInput("InferenceProb",
"Topk(values) the network output, float value indicating "
"probabilities of classification");
AddInput("Label", "Label of the training data");
// TODO(typhoonzero): support weight
AddOutput("AUC", "Area Under Curve caculations");
"A floating point `Tensor` of arbitrary shape and whose values"
"are in the range `[0, 1]`.");
AddInput("Label",
"A `Tensor` whose shape matches "
"`Inference`. Will be cast to `bool`.");
// TODO(typhoonzero): support weight input
AddOutput("AUC",
"A scalar `Tensor` representing the "
"current area-under-curve.");
AddAttr<std::string>("curve", "Possible curves are ROC and PR")
.SetDefault("ROC");
AddAttr<int>("num_thresholds",
......@@ -62,12 +60,16 @@ class AucOpMaker : public framework::OpProtoAndCheckerMaker {
AddComment(
R"DOC(Computes the AUC according forward output and label.
Best to use for binary classification evaluations.
If `label` can be values other than 0 and 1, it will be cast
to bool.
You can find the definations here:
https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve
Possible curves are:
ROC: Receiver operating characteristic
PR: Precision Recall
- ROC: Receiver operating characteristic
- PR: Precision Recall
)DOC");
}
};
......
......@@ -22,12 +22,15 @@ namespace operators {
using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename Place, typename T>
class AucKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* inference = ctx.Input<Tensor>("Inference");
auto* inference_prob = ctx.Input<Tensor>("InferenceProb");
auto* label = ctx.Input<Tensor>("Label");
auto* auc = ctx.Output<Tensor>("AUC");
......@@ -44,14 +47,20 @@ class AucKernel : public framework::OpKernel {
thresholds_list[0] = 0.0f - kEpsilon;
thresholds_list[num_thresholds - 1] = 1.0f + kEpsilon;
const int* inference_data = inference->data<int>();
const T* inference_prob_data = inference_prob->data<T>();
const T* label_data = label->data<T>();
size_t num_samples = inference->numel();
size_t num_samples = inference->dims()[0];
size_t class_dim = inference->dims()[1];
const T* inference_data = inference->data<T>();
Tensor label_casted;
label_casted.Resize(label->dims());
bool* label_casted_data = label_casted.mutable_data<bool>(ctx.GetPlace());
// create local tensor for storing the curve: TP, FN, TN, FP
const int* label_data = label->data<int>();
// cast label_data to bool
for (size_t i = 0; i < num_samples; i++) {
label_casted_data[i] = static_cast<bool>(label_data[i]);
}
// Create local tensor for storing the curve: TP, FN, TN, FP
// TODO(typhoonzero): put these tensors in Scope
// TODO(typhoonzero): use op to caculate these values.
Tensor true_positive, false_positive, true_negative, false_negative;
......@@ -72,22 +81,20 @@ class AucKernel : public framework::OpKernel {
// caculate TP, FN, TN, FP for current thresh
int tp, fn, tn, fp = 0;
for (size_t i = 0; i < num_samples; i++) {
for (size_t j = 0; j < class_dim; j++) {
if (inference_data[i * class_dim + j] == label_data[i]) {
if (inference_prob_data[i * class_dim + j] >= (*thresh)) {
if (label_casted_data[i]) {
if (inference_data[i] >= (*thresh)) {
tp++;
} else {
tn++;
}
} else {
if (inference_prob_data[i * class_dim + j] >= (*thresh)) {
if (inference_data[i] >= (*thresh)) {
fp++;
} else {
fn++;
}
}
}
}
// store rates
tp_data[idx_thresh] = tp;
fn_data[idx_thresh] = fn;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册
新手
引导
客服 返回
顶部