diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index bba2f3f469f92368fa593a1825cffbefd9952617..c225fa84948fea4c7d42adfbee1e9e963e7f3c9d 100755 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -1804,8 +1804,15 @@ def cross_entropy(input, "when weight is provided" .format(input.shape[-1], weight.shape[-1])) - weight_gather = paddle.gather_nd( - weight, label) #trans weight from class to sample, shape:N + valid_label = paddle.where(label == ignore_index, + paddle.zeros_like(label), label) + ignore_weight_mask = paddle.cast((label != ignore_index), input.dtype) + if ignore_weight_mask.ndim > 1 and ignore_weight_mask.shape[ + -1] == 1: + ignore_weight_mask = paddle.squeeze(ignore_weight_mask, -1) + weight_gather = paddle.gather_nd(weight, valid_label) + weight_gather = paddle.multiply(weight_gather, ignore_weight_mask) + input_shape = list(label.shape) weight_gather_reshape = reshape(weight_gather, shape=input_shape) out = paddle.multiply(out, weight_gather_reshape, name=weight_name)