diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py
index f4e8711a231e4efcb147b8f227afa16be77dcae3..f8e03e476d7f0c4169f7083d0dd8f1f4ab9d91a5 100755
--- a/python/paddle/nn/functional/loss.py
+++ b/python/paddle/nn/functional/loss.py
@@ -1668,12 +1668,12 @@ def cross_entropy(input,
                     format(invalid_label[0], 0))
             # TODO: Temporarily use paddle.nonzero instead of paddle.max
             # to detect and find out possible illegal label values
-            if len(paddle.nonzero(valid_label >= input.shape[-1])) > 0:
+            if len(paddle.nonzero(valid_label >= input.shape[axis])) > 0:
                 invalid_label = paddle.gather_nd(
-                    valid_label, paddle.nonzero(valid_label >= input.shape[-1]))
+                    valid_label, paddle.nonzero(valid_label >= input.shape[axis]))
                 raise ValueError(
                     "Target({}) is out of class_dimension's upper bound({})".
-                    format(invalid_label[0], input.shape[-1] - 1))
+                    format(invalid_label[0], input.shape[axis] - 1))
         _, out = _C_ops.softmax_with_cross_entropy(
             input, label, 'soft_label', soft_label, 'ignore_index',
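
For illustration, a minimal sketch of the situation this change addresses, with hypothetical tensor shapes and assuming dygraph mode: when the class dimension is selected with a non-default axis, the label upper bound must come from input.shape[axis] rather than input.shape[-1].

    import paddle
    import paddle.nn.functional as F

    # Hypothetical shapes: 3 classes live on axis=1, not on the last axis (size 8).
    logits = paddle.randn([2, 3, 4, 8])                 # (N, C, H, W) with C = 3
    labels = paddle.full([2, 4, 8], 5, dtype='int64')   # class id 5 is invalid since C = 3

    try:
        F.cross_entropy(logits, labels, axis=1)
    except ValueError as e:
        # With the bound taken from input.shape[axis], the invalid label is reported
        # against the class dimension (upper bound 2) instead of slipping past a check
        # that compared it to the last dimension (upper bound 7).
        print(e)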