From 618f7620e2f90c52b94ee243bf1827914a04fd82 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Sat, 1 Dec 2018 12:39:36 +0800 Subject: [PATCH] add enforce for auc (#14687) * add enforce for AUC, test=develop --- paddle/fluid/operators/metrics/auc_op.h | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/metrics/auc_op.h b/paddle/fluid/operators/metrics/auc_op.h index fb370842d1..4ab5cfe53c 100644 --- a/paddle/fluid/operators/metrics/auc_op.h +++ b/paddle/fluid/operators/metrics/auc_op.h @@ -75,8 +75,13 @@ class AucKernel : public framework::OpKernel { const auto *label_data = label->data(); for (size_t i = 0; i < batch_size; i++) { - uint32_t binIdx = static_cast( - inference_data[i * inference_width + 1] * num_thresholds); + auto predict_data = inference_data[i * inference_width + 1]; + PADDLE_ENFORCE_LE(predict_data, 1, + "The predict data must less or equal 1."); + PADDLE_ENFORCE_GE(predict_data, 0, + "The predict data must gather or equal 0."); + + uint32_t binIdx = static_cast(predict_data * num_thresholds); if (label_data[i]) { (*stat_pos)[binIdx] += 1.0; } else { -- GitLab