diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index 21b1e444ef844ed587c483a6565b87060008497a..fd4e83a6e8700c184bea1d3bbe4d78e8071c90df 100755 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -1651,10 +1651,8 @@ def cross_entropy(input, label = paddle.unsqueeze(label, axis=axis) if in_dygraph_mode(): if soft_label == False: - valid_label = paddle.where( - label == ignore_index, - paddle.zeros_like(label), - label) + valid_label = paddle.where(label == ignore_index, + paddle.zeros_like(label), label) if len(paddle.nonzero(valid_label < 0)) > 0: invalid_label = paddle.gather_nd( valid_label, paddle.nonzero(valid_label < 0)) @@ -1705,8 +1703,7 @@ def cross_entropy(input, if ignore_weight_mask.ndim > 1 and ignore_weight_mask.shape[ -1] == 1: ignore_weight_mask.squeeze_(-1) - weight_gather = _C_ops.gather_nd( - weight, valid_label) + weight_gather = _C_ops.gather_nd(weight, valid_label) weight_gather = _C_ops.elementwise_mul(weight_gather, ignore_weight_mask) input_shape = list(label.shape)