From 34c57120ebd42640b46677defea0daa274190484 Mon Sep 17 00:00:00 2001 From: Kaipeng Deng Date: Sat, 4 Jan 2020 13:00:18 +0800 Subject: [PATCH] polish cross_entropy ENFORCE (#22056) --- paddle/fluid/operators/math/cross_entropy.cc | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/operators/math/cross_entropy.cc b/paddle/fluid/operators/math/cross_entropy.cc index 9f7884fe05f..7a1ed47d182 100644 --- a/paddle/fluid/operators/math/cross_entropy.cc +++ b/paddle/fluid/operators/math/cross_entropy.cc @@ -53,7 +53,21 @@ class CrossEntropyFunctor { for (int i = 0; i < batch_size; ++i) { for (int j = 0; j < num_remain; j++) { int lbl = label_data[i * num_remain + j]; - PADDLE_ENFORCE((lbl >= 0 && lbl < axis_dim) || lbl == ignore_index); + if (lbl != ignore_index) { + PADDLE_ENFORCE_GE(lbl, 0, + platform::errors::OutOfRange( + "label value should >= 0 when label " + "value(%f) not equal to ignore_index(%f)", + lbl, ignore_index)); + PADDLE_ENFORCE_LT( + lbl, axis_dim, + platform::errors::OutOfRange( + "label value should less than the shape of axis dimension " + "when label value(%f) not equal to ignore_index(%f), But " + "received label value as %ld and shape of axis dimension " + "is %d", + lbl, ignore_index, lbl, axis_dim)); + } int index = i * num_classes + lbl * num_remain + j; int loss_idx = i * num_remain + j; loss_data[loss_idx] = -- GitLab