From fef94eac4e531b8043c422e45ba6bde3a4e5eac9 Mon Sep 17 00:00:00 2001 From: chajchaj <57249073+chajchaj@users.noreply.github.com> Date: Fri, 18 Sep 2020 10:46:09 +0800 Subject: [PATCH] fix cross_entropy bug of the axis parameter in log_softmax (#27311) --- python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py | 2 +- python/paddle/nn/functional/loss.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py b/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py index 4982cd19582..c6190590108 100644 --- a/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py +++ b/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py @@ -26,7 +26,7 @@ def stable_softmax(x): return exps / np.sum(exps) -def log_softmax(x, axis=-1): +def log_softmax(x, axis=1): softmax_out = np.apply_along_axis(stable_softmax, axis, x) return np.log(softmax_out) diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index da086c0955e..4395520eec7 100644 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -1093,7 +1093,7 @@ def cross_entropy(input, " 'none', but received %s, which is not allowed." % reduction) #step 1. log_softmax - log_softmax_out = paddle.nn.functional.log_softmax(input) + log_softmax_out = paddle.nn.functional.log_softmax(input, axis=1) if weight is not None and not isinstance(weight, Variable): raise ValueError( "The weight' is not a Variable, please convert to Variable.") -- GitLab