未验证 提交 08e72592 编写于 作者: S shentanyue 提交者: GitHub

[NPU] Fix softmax_with_cross_entropy (#50145)

* fix cross_entropy

* update
上级 02c36e00
...@@ -254,19 +254,41 @@ def fluid_softmax_with_cross_entropy( ...@@ -254,19 +254,41 @@ def fluid_softmax_with_cross_entropy(
# [1.15328646]) # [1.15328646])
""" """
if in_dygraph_mode(): if in_dygraph_mode():
if core.is_compiled_with_npu(): if core.is_compiled_with_custom_device("npu"):
softmax, backprop, loss = _legacy_C_ops.softmax_with_cross_entropy( if not soft_label:
logits, valid_label = (
label, paddle.cast(label != ignore_index, dtype=label.dtype)
'soft_label', * label
soft_label, )
'ignore_index', softmax, loss = _legacy_C_ops.softmax_with_cross_entropy(
ignore_index, logits,
'numeric_stable_mode', valid_label,
numeric_stable_mode, 'soft_label',
'axis', soft_label,
axis, 'ignore_index',
) ignore_index,
'numeric_stable_mode',
numeric_stable_mode,
'axis',
axis,
'use_softmax',
True,
)
else:
softmax, loss = _legacy_C_ops.softmax_with_cross_entropy(
logits,
label,
'soft_label',
soft_label,
'ignore_index',
ignore_index,
'numeric_stable_mode',
numeric_stable_mode,
'axis',
axis,
'use_softmax',
True,
)
else: else:
softmax, loss = _C_ops.cross_entropy_with_softmax( softmax, loss = _C_ops.cross_entropy_with_softmax(
logits, logits,
...@@ -293,7 +315,9 @@ def fluid_softmax_with_cross_entropy( ...@@ -293,7 +315,9 @@ def fluid_softmax_with_cross_entropy(
loss = helper.create_variable_for_type_inference(dtype=logits.dtype) loss = helper.create_variable_for_type_inference(dtype=logits.dtype)
outputs = {'Softmax': softmax, 'Loss': loss} outputs = {'Softmax': softmax, 'Loss': loss}
if core.is_compiled_with_npu() or core.is_compiled_with_mlu(): if core.is_compiled_with_custom_device(
"npu"
) or core.is_compiled_with_custom_device("mlu"):
backprop = helper.create_variable_for_type_inference( backprop = helper.create_variable_for_type_inference(
dtype=logits.dtype dtype=logits.dtype
) )
...@@ -2573,7 +2597,9 @@ def cross_entropy( ...@@ -2573,7 +2597,9 @@ def cross_entropy(
valid_label = ( valid_label = (
paddle.cast(label != ignore_index, dtype=label.dtype) * label paddle.cast(label != ignore_index, dtype=label.dtype) * label
) )
if core.is_compiled_with_npu() or core.is_compiled_with_mlu(): if core.is_compiled_with_custom_device(
"npu"
) or core.is_compiled_with_custom_device("mlu"):
if not soft_label: if not soft_label:
_, _, out = _legacy_C_ops.softmax_with_cross_entropy( _, _, out = _legacy_C_ops.softmax_with_cross_entropy(
input, input,
...@@ -2744,7 +2770,9 @@ def cross_entropy( ...@@ -2744,7 +2770,9 @@ def cross_entropy(
out = helper.create_variable_for_type_inference(dtype=input.dtype) out = helper.create_variable_for_type_inference(dtype=input.dtype)
outputs = {'Softmax': softmax, 'Loss': out} outputs = {'Softmax': softmax, 'Loss': out}
if core.is_compiled_with_npu() or core.is_compiled_with_mlu(): if core.is_compiled_with_custom_device(
"npu"
) or core.is_compiled_with_custom_device("mlu"):
backprop = helper.create_variable_for_type_inference( backprop = helper.create_variable_for_type_inference(
dtype=input.dtype dtype=input.dtype
) )
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册