未验证 提交 08e72592 编写于 作者: S shentanyue 提交者: GitHub

[NPU] Fix softmax_with_cross_entropy (#50145)

* fix cross_entropy

* update
上级 02c36e00
......@@ -254,8 +254,28 @@ def fluid_softmax_with_cross_entropy(
# [1.15328646])
"""
if in_dygraph_mode():
if core.is_compiled_with_npu():
softmax, backprop, loss = _legacy_C_ops.softmax_with_cross_entropy(
if core.is_compiled_with_custom_device("npu"):
if not soft_label:
valid_label = (
paddle.cast(label != ignore_index, dtype=label.dtype)
* label
)
softmax, loss = _legacy_C_ops.softmax_with_cross_entropy(
logits,
valid_label,
'soft_label',
soft_label,
'ignore_index',
ignore_index,
'numeric_stable_mode',
numeric_stable_mode,
'axis',
axis,
'use_softmax',
True,
)
else:
softmax, loss = _legacy_C_ops.softmax_with_cross_entropy(
logits,
label,
'soft_label',
......@@ -266,6 +286,8 @@ def fluid_softmax_with_cross_entropy(
numeric_stable_mode,
'axis',
axis,
'use_softmax',
True,
)
else:
softmax, loss = _C_ops.cross_entropy_with_softmax(
......@@ -293,7 +315,9 @@ def fluid_softmax_with_cross_entropy(
loss = helper.create_variable_for_type_inference(dtype=logits.dtype)
outputs = {'Softmax': softmax, 'Loss': loss}
if core.is_compiled_with_npu() or core.is_compiled_with_mlu():
if core.is_compiled_with_custom_device(
"npu"
) or core.is_compiled_with_custom_device("mlu"):
backprop = helper.create_variable_for_type_inference(
dtype=logits.dtype
)
......@@ -2573,7 +2597,9 @@ def cross_entropy(
valid_label = (
paddle.cast(label != ignore_index, dtype=label.dtype) * label
)
if core.is_compiled_with_npu() or core.is_compiled_with_mlu():
if core.is_compiled_with_custom_device(
"npu"
) or core.is_compiled_with_custom_device("mlu"):
if not soft_label:
_, _, out = _legacy_C_ops.softmax_with_cross_entropy(
input,
......@@ -2744,7 +2770,9 @@ def cross_entropy(
out = helper.create_variable_for_type_inference(dtype=input.dtype)
outputs = {'Softmax': softmax, 'Loss': out}
if core.is_compiled_with_npu() or core.is_compiled_with_mlu():
if core.is_compiled_with_custom_device(
"npu"
) or core.is_compiled_with_custom_device("mlu"):
backprop = helper.create_variable_for_type_inference(
dtype=input.dtype
)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册