diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index b70c1cf2496a6c55c3fecda8339bcc40ea25c4f5..cb7256a4d9706d312b7cb2ba755cd54c2855fe57 100755 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -2706,12 +2706,13 @@ def cross_entropy( # 2. else # numerator: loss's weighted sum # denominator: cal the sum of weight where the sample's class_index!=ignore_index - if ignore_index >= 0: + is_ignore = label == ignore_index + mask = ~is_ignore + if paddle.count_nonzero(is_ignore) > 0: # ignore label out_sum = _C_ops.sum(out, [], None, False) # for each label[i],set 1 or 0, according to ignore_index # mask[i]=0, if label[i]==ignore_index # mask[i]=1, otherwise - mask = label != ignore_index if weight is None: mask = paddle.cast(mask, dtype=out_sum.dtype) count = _C_ops.sum(mask, [], None, False) @@ -2878,12 +2879,13 @@ def cross_entropy( # 2. else # numerator: loss's weighted sum # denominator: cal the sum of weight where the sample's class_index!=ignore_index - if ignore_index >= 0: + is_ignore = label == ignore_index + mask = ~is_ignore + if paddle.count_nonzero(is_ignore) > 0: # ignore label out_sum = _legacy_C_ops.reduce_sum(out, 'reduce_all', True) # for each label[i],set 1 or 0, according to ignore_index # mask[i]=0, if label[i]==ignore_index # mask[i]=1, otherwise - mask = label != ignore_index if weight is None: mask = paddle.cast(mask, dtype=out_sum.dtype) count = _legacy_C_ops.reduce_sum(mask, 'reduce_all', True)