diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py
index cf4d5b1ed35afebfe23e1fea1a436c7a633a878c..eeb00625876468fac7ce3d1ebefd4b46a796d2c0 100755
--- a/python/paddle/nn/functional/loss.py
+++ b/python/paddle/nn/functional/loss.py
@@ -1411,11 +1411,13 @@ def cross_entropy(input,
                 out = core.ops.elementwise_mul(out, weight_gather_reshape)

             else:
-                for label_val in label.flatten():
-                    if label_val < 0 or label_val >= input.shape[-1]:
-                        raise ValueError(
-                            'Expected 0 <= label_value < class_dimension({}), but got label_value {}'.
-                            format(input.shape[-1], label_val.numpy()))
+                label_min = paddle.min(label)
+                label_max = paddle.max(label)
+                if label_min < 0 or label_max >= input.shape[-1]:
+                    raise ValueError(
+                        'Expected 0 <= label_value < class_dimension({}), but got {} <= label_value <= {} '.
+                        format(input.shape[-1],
+                               label_min.numpy(), label_max.numpy()))
             weight_gather = core.ops.gather_nd(weight, label)
             input_shape = list(label.shape)
             weight_gather_reshape = reshape(
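
A minimal standalone sketch (not part of the patch) of the validation pattern this change adopts: one paddle.min and one paddle.max over the whole label tensor replace the Python-level loop over label.flatten(). The class count and the example label tensor below are illustrative assumptions, not values taken from the patch.

    import paddle

    num_classes = 4  # assumed stand-in for input.shape[-1] in the patched code
    label = paddle.to_tensor([0, 2, 3, 1], dtype='int64')  # illustrative labels

    # Vectorized bounds check: a single reduction for the smallest and largest
    # label value, instead of iterating over every element in Python.
    label_min = paddle.min(label)
    label_max = paddle.max(label)
    if label_min < 0 or label_max >= num_classes:
        raise ValueError(
            'Expected 0 <= label_value < class_dimension({}), but got {} <= label_value <= {}'.
            format(num_classes, label_min.numpy(), label_max.numpy()))

For large label tensors this keeps the check inside two tensor reductions rather than one Python iteration (and one comparison) per element, which is the point of the change.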