未验证 提交 80d465ee 编写于 作者: H huangjun12 提交者: GitHub

fix ce bug (#48974)

上级 0f2683d6
...@@ -2706,12 +2706,13 @@ def cross_entropy( ...@@ -2706,12 +2706,13 @@ def cross_entropy(
# 2. else # 2. else
# numerator: loss's weighted sum # numerator: loss's weighted sum
# denominator: cal the sum of weight where the sample's class_index!=ignore_index # denominator: cal the sum of weight where the sample's class_index!=ignore_index
if ignore_index >= 0: is_ignore = label == ignore_index
mask = ~is_ignore
if paddle.count_nonzero(is_ignore) > 0: # ignore label
out_sum = _C_ops.sum(out, [], None, False) out_sum = _C_ops.sum(out, [], None, False)
# for each label[i],set 1 or 0, according to ignore_index # for each label[i],set 1 or 0, according to ignore_index
# mask[i]=0, if label[i]==ignore_index # mask[i]=0, if label[i]==ignore_index
# mask[i]=1, otherwise # mask[i]=1, otherwise
mask = label != ignore_index
if weight is None: if weight is None:
mask = paddle.cast(mask, dtype=out_sum.dtype) mask = paddle.cast(mask, dtype=out_sum.dtype)
count = _C_ops.sum(mask, [], None, False) count = _C_ops.sum(mask, [], None, False)
...@@ -2878,12 +2879,13 @@ def cross_entropy( ...@@ -2878,12 +2879,13 @@ def cross_entropy(
# 2. else # 2. else
# numerator: loss's weighted sum # numerator: loss's weighted sum
# denominator: cal the sum of weight where the sample's class_index!=ignore_index # denominator: cal the sum of weight where the sample's class_index!=ignore_index
if ignore_index >= 0: is_ignore = label == ignore_index
mask = ~is_ignore
if paddle.count_nonzero(is_ignore) > 0: # ignore label
out_sum = _legacy_C_ops.reduce_sum(out, 'reduce_all', True) out_sum = _legacy_C_ops.reduce_sum(out, 'reduce_all', True)
# for each label[i],set 1 or 0, according to ignore_index # for each label[i],set 1 or 0, according to ignore_index
# mask[i]=0, if label[i]==ignore_index # mask[i]=0, if label[i]==ignore_index
# mask[i]=1, otherwise # mask[i]=1, otherwise
mask = label != ignore_index
if weight is None: if weight is None:
mask = paddle.cast(mask, dtype=out_sum.dtype) mask = paddle.cast(mask, dtype=out_sum.dtype)
count = _legacy_C_ops.reduce_sum(mask, 'reduce_all', True) count = _legacy_C_ops.reduce_sum(mask, 'reduce_all', True)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册