From c512afd2d5bfcc1fe2125dadfdd299da32f6dfae Mon Sep 17 00:00:00 2001 From: Chenxiao Niu Date: Wed, 17 Aug 2022 14:59:54 +0800 Subject: [PATCH] [MLU] fix celoss to use valid_label. (#45201) --- python/paddle/nn/functional/loss.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index 0fe3a000ad..4e4f968e68 100755 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -2299,10 +2299,16 @@ def cross_entropy(input, raise ValueError("Target {} is out of upper bound.".format( label_max.item())) if core.is_compiled_with_npu() or core.is_compiled_with_mlu(): - _, _, out = _C_ops.softmax_with_cross_entropy( - input, label, 'soft_label', soft_label, 'ignore_index', - ignore_index, 'numeric_stable_mode', True, 'axis', axis, - 'use_softmax', use_softmax) + if soft_label == False: + _, _, out = _C_ops.softmax_with_cross_entropy( + input, valid_label, 'soft_label', soft_label, + 'ignore_index', ignore_index, 'numeric_stable_mode', True, + 'axis', axis, 'use_softmax', use_softmax) + else: + _, _, out = _C_ops.softmax_with_cross_entropy( + input, label, 'soft_label', soft_label, 'ignore_index', + ignore_index, 'numeric_stable_mode', True, 'axis', axis, + 'use_softmax', use_softmax) else: if in_dygraph_mode(): _, out = _C_ops.final_state_cross_entropy_with_softmax( -- GitLab