diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index c225fa84948fea4c7d42adfbee1e9e963e7f3c9d..1450101206244a1009db96b1e020f2319022c70c 100755 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -1806,7 +1806,8 @@ def cross_entropy(input, valid_label = paddle.where(label == ignore_index, paddle.zeros_like(label), label) - ignore_weight_mask = paddle.cast((label != ignore_index), input.dtype) + ignore_weight_mask = paddle.cast( + (label != ignore_index), input.dtype) if ignore_weight_mask.ndim > 1 and ignore_weight_mask.shape[ -1] == 1: ignore_weight_mask = paddle.squeeze(ignore_weight_mask, -1)