未验证 提交 c512afd2 编写于 作者: C Chenxiao Niu 提交者: GitHub

[MLU] fix celoss to use valid_label. (#45201)

上级 0e3b49d4
......@@ -2299,6 +2299,12 @@ def cross_entropy(input,
raise ValueError("Target {} is out of upper bound.".format(
label_max.item()))
if core.is_compiled_with_npu() or core.is_compiled_with_mlu():
if soft_label == False:
_, _, out = _C_ops.softmax_with_cross_entropy(
input, valid_label, 'soft_label', soft_label,
'ignore_index', ignore_index, 'numeric_stable_mode', True,
'axis', axis, 'use_softmax', use_softmax)
else:
_, _, out = _C_ops.softmax_with_cross_entropy(
input, label, 'soft_label', soft_label, 'ignore_index',
ignore_index, 'numeric_stable_mode', True, 'axis', axis,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册