From 489a64efeaaedfda5a27943a88200a2af522e877 Mon Sep 17 00:00:00 2001 From: qipengh Date: Wed, 30 Mar 2022 14:28:20 +0800 Subject: [PATCH] fix cross_entropy when run static graph mode of mlu and npu (#40621) --- python/paddle/nn/functional/loss.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index 9971df05fbc..e7763853bf7 100755 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -1818,12 +1818,16 @@ def cross_entropy(input, helper = LayerHelper('softmax_with_cross_entropy', **locals()) softmax = helper.create_variable_for_type_inference(dtype=input.dtype) out = helper.create_variable_for_type_inference(dtype=input.dtype) + + outputs = {'Softmax': softmax, 'Loss': out} + if core.is_compiled_with_npu() or core.is_compiled_with_mlu(): + backprop = helper.create_variable_for_type_inference(dtype=input.dtype) + outputs['Backprop'] = backprop helper.append_op( type='softmax_with_cross_entropy', inputs={'Logits': input, 'Label': label}, - outputs={'Softmax': softmax, - 'Loss': out}, + outputs=outputs, attrs=attrs) if weight is not None: -- GitLab