未验证 提交 c512afd2 编写于 作者: C Chenxiao Niu 提交者: GitHub

[MLU] fix celoss to use valid_label. (#45201)

上级 0e3b49d4
......@@ -2299,10 +2299,16 @@ def cross_entropy(input,
raise ValueError("Target {} is out of upper bound.".format(
label_max.item()))
if core.is_compiled_with_npu() or core.is_compiled_with_mlu():
_, _, out = _C_ops.softmax_with_cross_entropy(
input, label, 'soft_label', soft_label, 'ignore_index',
ignore_index, 'numeric_stable_mode', True, 'axis', axis,
'use_softmax', use_softmax)
if soft_label == False:
_, _, out = _C_ops.softmax_with_cross_entropy(
input, valid_label, 'soft_label', soft_label,
'ignore_index', ignore_index, 'numeric_stable_mode', True,
'axis', axis, 'use_softmax', use_softmax)
else:
_, _, out = _C_ops.softmax_with_cross_entropy(
input, label, 'soft_label', soft_label, 'ignore_index',
ignore_index, 'numeric_stable_mode', True, 'axis', axis,
'use_softmax', use_softmax)
else:
if in_dygraph_mode():
_, out = _C_ops.final_state_cross_entropy_with_softmax(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册