diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index f8e03e476d7f0c4169f7083d0dd8f1f4ab9d91a5..5bb317cf3e746605682ceb2f44146087c74ccf43 100755 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -1712,9 +1712,9 @@ def cross_entropy(input, if ignore_weight_mask.ndim > 1 and ignore_weight_mask.shape[ axis] == 1: ignore_weight_mask.squeeze_(axis) - if axis != -1: + if axis != -1 and axis != valid_label.ndim - 1: temp_perm = list(range(axis % valid_label.ndim)) \ - + list(range((axis + 1) % valid_label.ndim, valid_label.ndim)) \ + + list(range((axis % valid_label.ndim + 1) , valid_label.ndim)) \ + [axis%valid_label.ndim] weight_gather = _C_ops.gather_nd( weight, valid_label.transpose(temp_perm))