diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index 0fe3a000ade71077b3bcf9c46330e24f19cd1c71..4e4f968e68622e63f2836f942751f64d4c7b04cf 100755 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -2299,10 +2299,16 @@ def cross_entropy(input, raise ValueError("Target {} is out of upper bound.".format( label_max.item())) if core.is_compiled_with_npu() or core.is_compiled_with_mlu(): - _, _, out = _C_ops.softmax_with_cross_entropy( - input, label, 'soft_label', soft_label, 'ignore_index', - ignore_index, 'numeric_stable_mode', True, 'axis', axis, - 'use_softmax', use_softmax) + if soft_label == False: + _, _, out = _C_ops.softmax_with_cross_entropy( + input, valid_label, 'soft_label', soft_label, + 'ignore_index', ignore_index, 'numeric_stable_mode', True, + 'axis', axis, 'use_softmax', use_softmax) + else: + _, _, out = _C_ops.softmax_with_cross_entropy( + input, label, 'soft_label', soft_label, 'ignore_index', + ignore_index, 'numeric_stable_mode', True, 'axis', axis, + 'use_softmax', use_softmax) else: if in_dygraph_mode(): _, out = _C_ops.final_state_cross_entropy_with_softmax(