diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index c1800a781d4ba08e7d479801f79fc87db77f5513..4d09f1d5c38fc8cfc0be3cd0ebc16fc3e89edbc5 100755 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -1703,11 +1703,16 @@ def cross_entropy(input, "when weight is provided" \ .format(input.shape[axis], weight.shape[-1])) - valid_label = paddle.where(label == ignore_index, - paddle.zeros_like(label), label) + ignore_weight_mask = ( + label != ignore_index) # ignored position will be False + + valid_label = paddle.cast( + ignore_weight_mask, + dtype=label.dtype) * label # ignored position will be 0 + + ignore_weight_mask = paddle.cast( + ignore_weight_mask, out.dtype) # convert from 0 to 0.0 - ignore_weight_mask = paddle.cast((label != ignore_index), - out.dtype) if ignore_weight_mask.ndim > 1 and ignore_weight_mask.shape[ axis] == 1: # TODO: Temporarily use squeeze instead of squeeze_ @@ -1821,10 +1826,16 @@ def cross_entropy(input, "when weight is provided" \ .format(input.shape[axis], weight.shape[-1])) - valid_label = paddle.where(label == ignore_index, - paddle.zeros_like(label), label) - ignore_weight_mask = paddle.cast((label != ignore_index), - input.dtype) + ignore_weight_mask = ( + label != ignore_index) # ignored position will be False + + valid_label = paddle.cast( + ignore_weight_mask, + dtype=label.dtype) * label # ignored position will be 0 + + ignore_weight_mask = paddle.cast(ignore_weight_mask, + out.dtype) # convert from 0 to 0.0 + if ignore_weight_mask.ndim > 1 and ignore_weight_mask.shape[ axis] == 1: ignore_weight_mask = paddle.squeeze(ignore_weight_mask, axis)