From cc9028b90ef50a825a722c55e5fda4b7cd26b0d6 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Sun, 2 Dec 2018 10:17:35 +0800 Subject: [PATCH] cherry-pick enforce for auc (#14687) (#14694) * add enforce for AUC, test=release/1.2 --- paddle/fluid/operators/metrics/auc_op.h | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/metrics/auc_op.h b/paddle/fluid/operators/metrics/auc_op.h index fb370842d..4ab5cfe53 100644 --- a/paddle/fluid/operators/metrics/auc_op.h +++ b/paddle/fluid/operators/metrics/auc_op.h @@ -75,8 +75,13 @@ class AucKernel : public framework::OpKernel { const auto *label_data = label->data(); for (size_t i = 0; i < batch_size; i++) { - uint32_t binIdx = static_cast( - inference_data[i * inference_width + 1] * num_thresholds); + auto predict_data = inference_data[i * inference_width + 1]; + PADDLE_ENFORCE_LE(predict_data, 1, + "The predict data must less or equal 1."); + PADDLE_ENFORCE_GE(predict_data, 0, + "The predict data must gather or equal 0."); + + uint32_t binIdx = static_cast(predict_data * num_thresholds); if (label_data[i]) { (*stat_pos)[binIdx] += 1.0; } else { -- GitLab