diff --git a/paddle/fluid/operators/auc_op.cc b/paddle/fluid/operators/auc_op.cc index 6bd3e491bccb037406b784147dc9f91049b34d53..5edecd18e673da326ec119cf9a383f24f8045089 100644 --- a/paddle/fluid/operators/auc_op.cc +++ b/paddle/fluid/operators/auc_op.cc @@ -24,15 +24,16 @@ class AucOp : public framework::OperatorWithKernel { protected: void InferShape(framework::InferShapeContext *ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("Out"), "Input of Out should not be null."); - PADDLE_ENFORCE(ctx->HasInput("Indices"), - "Input of Indices should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Predict"), + "Input of Out should not be null."); PADDLE_ENFORCE(ctx->HasInput("Label"), "Input of Label should not be null."); - auto inference_height = ctx->GetInputDim("Out")[0]; + auto predict_width = ctx->GetInputDim("Predict")[1]; + PADDLE_ENFORCE_EQ(predict_width, 2, "Only support binary classification"); + auto predict_height = ctx->GetInputDim("Predict")[0]; auto label_height = ctx->GetInputDim("Label")[0]; - PADDLE_ENFORCE_EQ(inference_height, label_height, + PADDLE_ENFORCE_EQ(predict_height, label_height, "Out and Label should have same height."); int num_thres = ctx->Attrs().Get("num_thresholds"); @@ -43,14 +44,14 @@ class AucOp : public framework::OperatorWithKernel { ctx->SetOutputDim("FPOut", {num_thres}); ctx->SetOutputDim("FNOut", {num_thres}); - ctx->ShareLoD("Out", /*->*/ "AUC"); + ctx->ShareLoD("Predict", /*->*/ "AUC"); } protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { return framework::OpKernelType( - framework::ToDataType(ctx.Input("Out")->type()), + framework::ToDataType(ctx.Input("Predict")->type()), ctx.device_context()); } }; @@ -58,18 +59,13 @@ class AucOp : public framework::OperatorWithKernel { class AucOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { - AddInput("Out", - "A floating point 2D tensor, values are in the range [0, 1]." - "Each row is sorted in descending order. This input should be the" - "output of topk." + AddInput("Predict", + "A floating point 2D tensor with shape [batch_size, 2], values " + "are in the range [0, 1]." "Typically, this tensor indicates the probability of each label"); - AddInput("Indices", - "An int 2D tensor, indicating the indices of original" - "tensor before sorting. Typically, this tensor indicates which " - "label the probability stands for."); AddInput("Label", - "A 2D int tensor indicating the label of the training data." - "The height is batch size and width is always 1."); + "A 2D int tensor indicating the label of the training data. " + "shape: [batch_size, 1]"); AddInput("TP", "True-Positive value."); AddInput("FP", "False-Positive value."); AddInput("TN", "True-Negative value."); diff --git a/paddle/fluid/operators/auc_op.h b/paddle/fluid/operators/auc_op.h index 58fefc1600dfb7df3e3d71959c047865ed5e2e39..0a18585edb54a76aff5ae72ecc71e0eebb9f9361 100644 --- a/paddle/fluid/operators/auc_op.h +++ b/paddle/fluid/operators/auc_op.h @@ -31,7 +31,7 @@ template class AucKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - auto* inference = ctx.Input("Out"); + auto* predict = ctx.Input("Predict"); auto* label = ctx.Input("Label"); auto* auc = ctx.Output("AUC"); // Only use output var for now, make sure it's persistable and @@ -41,24 +41,24 @@ class AucKernel : public framework::OpKernel { auto* true_negative = ctx.Output("TNOut"); auto* false_negative = ctx.Output("FNOut"); - float* auc_data = auc->mutable_data(ctx.GetPlace()); + auto* auc_data = auc->mutable_data(ctx.GetPlace()); std::string curve = ctx.Attr("curve"); int num_thresholds = ctx.Attr("num_thresholds"); - std::vector thresholds_list; + std::vector thresholds_list; thresholds_list.reserve(num_thresholds); for (int i = 1; i < num_thresholds - 1; i++) { - thresholds_list[i] = static_cast(i) / (num_thresholds - 1); + thresholds_list[i] = static_cast(i) / (num_thresholds - 1); } - const float kEpsilon = 1e-7; + const double kEpsilon = 1e-7; thresholds_list[0] = 0.0f - kEpsilon; thresholds_list[num_thresholds - 1] = 1.0f + kEpsilon; - size_t batch_size = inference->dims()[0]; - size_t inference_width = inference->dims()[1]; + size_t batch_size = predict->dims()[0]; + size_t inference_width = predict->dims()[1]; - const T* inference_data = inference->data(); - const int64_t* label_data = label->data(); + const T* inference_data = predict->data(); + const auto* label_data = label->data(); auto* tp_data = true_positive->mutable_data(ctx.GetPlace()); auto* fn_data = false_negative->mutable_data(ctx.GetPlace()); @@ -66,20 +66,19 @@ class AucKernel : public framework::OpKernel { auto* fp_data = false_positive->mutable_data(ctx.GetPlace()); for (int idx_thresh = 0; idx_thresh < num_thresholds; idx_thresh++) { - // caculate TP, FN, TN, FP for current thresh + // calculate TP, FN, TN, FP for current thresh int64_t tp = 0, fn = 0, tn = 0, fp = 0; for (size_t i = 0; i < batch_size; i++) { - // NOTE: label_data used as bool, labels >0 will be treated as true. + // NOTE: label_data used as bool, labels > 0 will be treated as true. if (label_data[i]) { - // use first(max) data in each row - if (inference_data[i * inference_width] >= + if (inference_data[i * inference_width + 1] >= (thresholds_list[idx_thresh])) { tp++; } else { fn++; } } else { - if (inference_data[i * inference_width] >= + if (inference_data[i * inference_width + 1] >= (thresholds_list[idx_thresh])) { fp++; } else { @@ -94,21 +93,21 @@ class AucKernel : public framework::OpKernel { fp_data[idx_thresh] += fp; } // epsilon to avoid divide by zero. - float epsilon = 1e-6; + double epsilon = 1e-6; // Riemann sum to caculate auc. Tensor tp_rate, fp_rate, rec_rate; tp_rate.Resize({num_thresholds}); fp_rate.Resize({num_thresholds}); rec_rate.Resize({num_thresholds}); - float* tp_rate_data = tp_rate.mutable_data(ctx.GetPlace()); - float* fp_rate_data = fp_rate.mutable_data(ctx.GetPlace()); - float* rec_rate_data = rec_rate.mutable_data(ctx.GetPlace()); + auto* tp_rate_data = tp_rate.mutable_data(ctx.GetPlace()); + auto* fp_rate_data = fp_rate.mutable_data(ctx.GetPlace()); + auto* rec_rate_data = rec_rate.mutable_data(ctx.GetPlace()); for (int i = 0; i < num_thresholds; i++) { - tp_rate_data[i] = (static_cast(tp_data[i]) + epsilon) / + tp_rate_data[i] = (static_cast(tp_data[i]) + epsilon) / (tp_data[i] + fn_data[i] + epsilon); fp_rate_data[i] = - static_cast(fp_data[i]) / (fp_data[i] + tn_data[i] + epsilon); - rec_rate_data[i] = (static_cast(tp_data[i]) + epsilon) / + static_cast(fp_data[i]) / (fp_data[i] + tn_data[i] + epsilon); + rec_rate_data[i] = (static_cast(tp_data[i]) + epsilon) / (tp_data[i] + fp_data[i] + epsilon); } *auc_data = 0.0f; diff --git a/python/paddle/fluid/layers/metric_op.py b/python/paddle/fluid/layers/metric_op.py index 194a16b123c441ac1318b8ce58158f67e2a8093d..e7d7a9e826de95514b6f2e04e7408075ab0b8cb6 100644 --- a/python/paddle/fluid/layers/metric_op.py +++ b/python/paddle/fluid/layers/metric_op.py @@ -114,23 +114,13 @@ def auc(input, label, curve='ROC', num_thresholds=200, topk=1): prediction = network(image, is_infer=True) auc_out=fluid.layers.auc(input=prediction, label=label) """ - - warnings.warn( - "This interface is not recommended, fluid.layers.auc compute the auc at every minibatch, \ - but can not aggregate them and get the pass AUC, because pass \ - auc can not be averaged with weighted from the minibatch auc value. \ - Please use fluid.metrics.Auc, it can compute the auc value via Python natively, \ - which can get every minibatch and every pass auc value.", Warning) helper = LayerHelper("auc", **locals()) - topk_out = helper.create_tmp_variable(dtype=input.dtype) - topk_indices = helper.create_tmp_variable(dtype="int64") - topk_out, topk_indices = nn.topk(input, k=k) - auc_out = helper.create_tmp_variable(dtype="float32") + auc_out = helper.create_tmp_variable(dtype="float64") # make tp, tn, fp, fn persistable, so that can accumulate all batches. - tp = helper.create_global_variable(persistable=True) - tn = helper.create_global_variable(persistable=True) - fp = helper.create_global_variable(persistable=True) - fn = helper.create_global_variable(persistable=True) + tp = helper.create_global_variable(persistable=True, dtype='int64') + tn = helper.create_global_variable(persistable=True, dtype='int64') + fp = helper.create_global_variable(persistable=True, dtype='int64') + fn = helper.create_global_variable(persistable=True, dtype='int64') for var in [tp, tn, fp, fn]: helper.set_variable_initializer( var, Constant( @@ -139,8 +129,7 @@ def auc(input, label, curve='ROC', num_thresholds=200, topk=1): helper.append_op( type="auc", inputs={ - "Out": [topk_out], - "Indices": [topk_indices], + "Predict": [input], "Label": [label], "TP": [tp], "TN": [tn], @@ -156,4 +145,4 @@ def auc(input, label, curve='ROC', num_thresholds=200, topk=1): "FPOut": [fp], "FNOut": [fn] }) - return auc_out + return auc_out, [tp, tn, fp, fn] diff --git a/python/paddle/fluid/metrics.py b/python/paddle/fluid/metrics.py index 17bb0826a6ea86c98a069263dfab84b99e1177ad..b37b09ac81687882443c948569d9c4fca9310f78 100644 --- a/python/paddle/fluid/metrics.py +++ b/python/paddle/fluid/metrics.py @@ -591,7 +591,7 @@ class Auc(MetricBase): for i in range(self._num_thresholds - 2)] thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon] - # caculate TP, FN, TN, FP count + # calculate TP, FN, TN, FP count for idx_thresh, thresh in enumerate(thresholds): tp, fn, tn, fp = 0, 0, 0, 0 for i, lbl in enumerate(labels): diff --git a/python/paddle/fluid/tests/unittests/test_auc_op.py b/python/paddle/fluid/tests/unittests/test_auc_op.py index 6bd5e2332a99693f5e53e147491aa83c35859548..6580c70ca68c4ba24919f03d071f6f88fb68953c 100644 --- a/python/paddle/fluid/tests/unittests/test_auc_op.py +++ b/python/paddle/fluid/tests/unittests/test_auc_op.py @@ -15,13 +15,13 @@ import unittest import numpy as np from op_test import OpTest +from paddle.fluid import metrics class TestAucOp(OpTest): def setUp(self): self.op_type = "auc" pred = np.random.random((128, 2)).astype("float32") - indices = np.random.randint(0, 2, (128, 2)) labels = np.random.randint(0, 2, (128, 1)) num_thresholds = 200 tp = np.zeros((num_thresholds, )).astype("int64") @@ -30,8 +30,7 @@ class TestAucOp(OpTest): fn = np.zeros((num_thresholds, )).astype("int64") self.inputs = { - 'Out': pred, - 'Indices': indices, + 'Predict': pred, 'Label': labels, 'TP': tp, 'TN': tn, @@ -39,57 +38,18 @@ class TestAucOp(OpTest): 'FN': fn } self.attrs = {'curve': 'ROC', 'num_thresholds': num_thresholds} - # NOTE: sklearn use a different way to generate thresholds - # which will cause the result differs slightly: - # from sklearn.metrics import roc_curve, auc - # fpr, tpr, thresholds = roc_curve(labels, pred) - # auc_value = auc(fpr, tpr) - # we caculate AUC again using numpy for testing - kepsilon = 1e-7 # to account for floating point imprecisions - thresholds = [(i + 1) * 1.0 / (num_thresholds - 1) - for i in range(num_thresholds - 2)] - thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon] - # caculate TP, FN, TN, FP count - tp_list = np.ndarray((num_thresholds, )) - fn_list = np.ndarray((num_thresholds, )) - tn_list = np.ndarray((num_thresholds, )) - fp_list = np.ndarray((num_thresholds, )) - for idx_thresh, thresh in enumerate(thresholds): - tp, fn, tn, fp = 0, 0, 0, 0 - for i, lbl in enumerate(labels): - if lbl: - if pred[i, 0] >= thresh: - tp += 1 - else: - fn += 1 - else: - if pred[i, 0] >= thresh: - fp += 1 - else: - tn += 1 - tp_list[idx_thresh] = tp - fn_list[idx_thresh] = fn - tn_list[idx_thresh] = tn - fp_list[idx_thresh] = fp - - epsilon = 1e-6 - tpr = (tp_list.astype("float32") + epsilon) / ( - tp_list + fn_list + epsilon) - fpr = fp_list.astype("float32") / (fp_list + tn_list + epsilon) - rec = (tp_list.astype("float32") + epsilon) / ( - tp_list + fp_list + epsilon) - - x = fpr[:num_thresholds - 1] - fpr[1:] - y = (tpr[:num_thresholds - 1] + tpr[1:]) / 2.0 - auc_value = np.sum(x * y) + python_auc = metrics.Auc(name="auc", + curve='ROC', + num_thresholds=num_thresholds) + python_auc.update(pred, labels) self.outputs = { - 'AUC': auc_value, - 'TPOut': tp_list, - 'FNOut': fn_list, - 'TNOut': tn_list, - 'FPOut': fp_list + 'AUC': python_auc.eval(), + 'TPOut': python_auc.tp_list, + 'FNOut': python_auc.fn_list, + 'TNOut': python_auc.tn_list, + 'FPOut': python_auc.fp_list } def test_check_output(self):