未验证 提交 0e048fc6 编写于 作者: S sneaxiy 提交者: GitHub

fix cross entropy (#41541)

上级 f581f5bf
...@@ -1905,7 +1905,7 @@ def cross_entropy(input, ...@@ -1905,7 +1905,7 @@ def cross_entropy(input,
if reduction == "sum": if reduction == "sum":
return paddle.sum(out, name=name) return paddle.sum(out, name=name)
elif reduction == "mean": elif reduction == "mean":
if ignore_index != -100: if ignore_index >= 0:
out_sum = paddle.sum(out, name=name) out_sum = paddle.sum(out, name=name)
# for each label[i],set 1 or 0, according to ignore_index # for each label[i],set 1 or 0, according to ignore_index
# mask[i]=0, if label[i]==ignore_index # mask[i]=0, if label[i]==ignore_index
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册