From 08e725929934334268556875abd53ce55f4b5364 Mon Sep 17 00:00:00 2001
From: shentanyue <34421038+shentanyue@users.noreply.github.com>
Date: Fri, 3 Feb 2023 15:24:43 +0800
Subject: [PATCH] [NPU] Fix softmax_with_cross_entropy (#50145)

* fix cross_entropy

* update
---
 python/paddle/nn/functional/loss.py | 60 +++++++++++++++++++++--------
 1 file changed, 44 insertions(+), 16 deletions(-)

diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py
index 001efd74a6..a441183ca8 100644
--- a/python/paddle/nn/functional/loss.py
+++ b/python/paddle/nn/functional/loss.py
@@ -254,19 +254,41 @@ def fluid_softmax_with_cross_entropy(
             # [1.15328646])
     """
     if in_dygraph_mode():
-        if core.is_compiled_with_npu():
-            softmax, backprop, loss = _legacy_C_ops.softmax_with_cross_entropy(
-                logits,
-                label,
-                'soft_label',
-                soft_label,
-                'ignore_index',
-                ignore_index,
-                'numeric_stable_mode',
-                numeric_stable_mode,
-                'axis',
-                axis,
-            )
+        if core.is_compiled_with_custom_device("npu"):
+            if not soft_label:
+                valid_label = (
+                    paddle.cast(label != ignore_index, dtype=label.dtype)
+                    * label
+                )
+                softmax, loss = _legacy_C_ops.softmax_with_cross_entropy(
+                    logits,
+                    valid_label,
+                    'soft_label',
+                    soft_label,
+                    'ignore_index',
+                    ignore_index,
+                    'numeric_stable_mode',
+                    numeric_stable_mode,
+                    'axis',
+                    axis,
+                    'use_softmax',
+                    True,
+                )
+            else:
+                softmax, loss = _legacy_C_ops.softmax_with_cross_entropy(
+                    logits,
+                    label,
+                    'soft_label',
+                    soft_label,
+                    'ignore_index',
+                    ignore_index,
+                    'numeric_stable_mode',
+                    numeric_stable_mode,
+                    'axis',
+                    axis,
+                    'use_softmax',
+                    True,
+                )
         else:
             softmax, loss = _C_ops.cross_entropy_with_softmax(
                 logits,
@@ -293,7 +315,9 @@ def fluid_softmax_with_cross_entropy(
     loss = helper.create_variable_for_type_inference(dtype=logits.dtype)
 
     outputs = {'Softmax': softmax, 'Loss': loss}
-    if core.is_compiled_with_npu() or core.is_compiled_with_mlu():
+    if core.is_compiled_with_custom_device(
+        "npu"
+    ) or core.is_compiled_with_custom_device("mlu"):
         backprop = helper.create_variable_for_type_inference(
             dtype=logits.dtype
         )
@@ -2573,7 +2597,9 @@ def cross_entropy(
             valid_label = (
                 paddle.cast(label != ignore_index, dtype=label.dtype) * label
             )
-            if core.is_compiled_with_npu() or core.is_compiled_with_mlu():
+            if core.is_compiled_with_custom_device(
+                "npu"
+            ) or core.is_compiled_with_custom_device("mlu"):
                 if not soft_label:
                     _, _, out = _legacy_C_ops.softmax_with_cross_entropy(
                         input,
@@ -2744,7 +2770,9 @@ def cross_entropy(
     out = helper.create_variable_for_type_inference(dtype=input.dtype)
 
     outputs = {'Softmax': softmax, 'Loss': out}
-    if core.is_compiled_with_npu() or core.is_compiled_with_mlu():
+    if core.is_compiled_with_custom_device(
+        "npu"
+    ) or core.is_compiled_with_custom_device("mlu"):
         backprop = helper.create_variable_for_type_inference(
             dtype=input.dtype
         )
--
GitLab
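
Note (not part of the patch): a minimal standalone sketch of the label-masking step used in the hard-label NPU branch above. It assumes PaddlePaddle is installed; the tensor values and ignore_index value are illustrative only. Entries equal to ignore_index are replaced with class id 0 before the op is called, so the kernel never receives the (typically negative) ignore_index value as a class index.

    import paddle

    ignore_index = -100
    label = paddle.to_tensor([[1], [2], [-100]], dtype='int64')

    # Build a 0/1 mask: 1 where the label is a valid class id, 0 where it
    # equals ignore_index. Multiplying by label keeps valid ids unchanged
    # and turns the ignored positions into class 0.
    valid_label = (
        paddle.cast(label != ignore_index, dtype=label.dtype) * label
    )
    print(valid_label.numpy())  # valid ids stay; the ignored entry becomes 0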