diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index 831d5e9207d0d8ab7796b4cfdba982fab3fd3359..3f12fb759d16a236d3292f6c0fb580300ffce211 100755 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -1653,12 +1653,16 @@ def cross_entropy(input, if soft_label == False: valid_label = paddle.where(label == ignore_index, paddle.zeros_like(label), label) + # TODO: Temporarily use paddle.nonzero instead of paddle.max + # to detect and find out possible illegal label values if len(paddle.nonzero(valid_label < 0)) > 0: invalid_label = paddle.gather_nd( valid_label, paddle.nonzero(valid_label < 0)) raise ValueError( "Target({}) is out of class_dimension's lower bound({})". format(invalid_label[0], 0)) + # TODO: Temporarily use paddle.nonzero instead of paddle.max + # to detect and find out possible illegal label values if len(paddle.nonzero(valid_label >= input.shape[-1])) > 0: invalid_label = paddle.gather_nd( valid_label, paddle.nonzero(valid_label >= input.shape[-1]))