From 3a255881b2d633dc750a72b7bc551b7de04b42b8 Mon Sep 17 00:00:00 2001 From: chajchaj <57249073+chajchaj@users.noreply.github.com> Date: Tue, 6 Apr 2021 10:54:29 +0800 Subject: [PATCH] fix use_softmax=False does not work, test=develop (#32035) --- python/paddle/nn/functional/loss.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index e5e3fa7bf8..52c605d5bb 100755 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -1388,6 +1388,8 @@ def cross_entropy(input, "should be '-100', but received %s, which is not allowed." % ignore_index) + softmax_switch = use_softmax + input_dims = len(list(input.shape)) label_dims = len(list(label.shape)) if input_dims - 1 != label_dims and input_dims != label_dims: @@ -1400,7 +1402,7 @@ def cross_entropy(input, _, out = core.ops.softmax_with_cross_entropy( input, label, 'soft_label', soft_label, 'ignore_index', ignore_index, 'numeric_stable_mode', True, 'axis', axis, - 'use_softmax', use_softmax) + 'softmax_switch', softmax_switch) if weight is not None: @@ -1482,7 +1484,7 @@ def cross_entropy(input, 'ignore_index': ignore_index, 'numeric_stable_mode': True, 'axis': axis, - 'use_softmax': use_softmax + 'softmax_switch': softmax_switch } helper = LayerHelper('softmax_with_cross_entropy', **locals()) softmax = helper.create_variable_for_type_inference(dtype=input.dtype) -- GitLab