diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index 93c1820593e2094ef811d11d79c87121668e9cd3..46ff6bb5fb4d446cfc84ffbbf50ee61e9abe564c 100755 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -1811,7 +1811,9 @@ def cross_entropy(input, .format(input.shape[-1], weight.shape[-1])) valid_label = paddle.where( label == ignore_index, - paddle.zeros([1], dtype=label.dtype), label) + paddle.zeros( + [1], dtype=label.dtype), + label) if (len(paddle.nonzero(valid_label < 0)) > 0) or ( len(paddle.nonzero(valid_label >= input.shape[-1])) > 0): invalid_label = paddle.gather_nd(