From 68e7de26c003e3404690d9e59d646a64350bc53f Mon Sep 17 00:00:00 2001 From: chajchaj <306536853@qq.com> Date: Thu, 1 Apr 2021 06:20:05 +0000 Subject: [PATCH] fix use_softmax=False does not work, test=develop --- python/paddle/nn/functional/loss.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index 1dad1632e26..6c8a2d1cbce 100755 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -1388,6 +1388,8 @@ def cross_entropy(input, "should be '-100', but received %s, which is not allowed." % ignore_index) + softmax_switch = use_softmax + input_dims = len(list(input.shape)) label_dims = len(list(label.shape)) if input_dims - 1 != label_dims and input_dims != label_dims: @@ -1400,7 +1402,7 @@ def cross_entropy(input, _, out = core.ops.softmax_with_cross_entropy( input, label, 'soft_label', soft_label, 'ignore_index', ignore_index, 'numeric_stable_mode', True, 'axis', axis, - 'use_softmax', use_softmax) + 'softmax_switch', softmax_switch) if weight is not None: @@ -1482,7 +1484,7 @@ def cross_entropy(input, 'ignore_index': ignore_index, 'numeric_stable_mode': True, 'axis': axis, - 'use_softmax': use_softmax + 'softmax_switch': softmax_switch } helper = LayerHelper('softmax_with_cross_entropy', **locals()) softmax = helper.create_variable_for_type_inference(dtype=input.dtype) -- GitLab